#
# Copyright 2018-2019 IBM Corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import logging

import numpy as np
import tensorflow as tf
from keras import models
from keras.backend import clear_session

from maxfw.model import MAXModelWrapper

from config import DEFAULT_MODEL_PATH, DEFAULT_MODEL_FILE, SEED_TEXT_LEN, MODEL_META_DATA as model_meta

logger = logging.getLogger()

# (Fixed) length of seed text that can serve as input to the generative model.
# NOTE(review): this duplicates SEED_TEXT_LEN imported from config, which is
# what the code below actually uses; kept only for backward compatibility in
# case another module imports it — confirm and remove if unused elsewhere.
_SEED_TEXT_LEN = 256


class ModelWrapper(MAXModelWrapper):
    """Model wrapper for a Keras character-level text-generation model.

    Loads a trained Keras model plus its character/index vocabulary assets,
    and generates text one character at a time from a fixed-length seed.
    """

    MODEL_META_DATA = model_meta

    def __init__(self, path=DEFAULT_MODEL_PATH, model_file=DEFAULT_MODEL_FILE):
        """Load the Keras model and its vocabulary assets.

        Args:
            path: Directory containing the model file and vocab assets.
            model_file: Filename of the serialized Keras model inside `path`.
        """
        logger.info('Loading model from: {}...'.format(path))
        model_path = '{}/{}'.format(path, model_file)
        # Reset any prior Keras state, then load the model into a dedicated
        # TF graph so predictions can run on it from request-handler threads.
        clear_session()
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.model = models.load_model(model_path)
            logger.info('Loaded model: {}'.format(self.model.name))
        self._load_assets(path)

    def _load_assets(self, path):
        """Load the char->index and index->char vocabulary mappings.

        Args:
            path: Directory containing `char_indices.txt` and
                `indices_char.txt` (both JSON dictionaries).
        """
        with open('{}/char_indices.txt'.format(path)) as f:
            self.char_indices = json.loads(f.read())
        self.chars = sorted(self.char_indices.keys())
        self.num_chars = len(self.chars)
        # indices_char maps stringified integer indices back to characters
        # (JSON object keys are always strings).
        with open('{}/indices_char.txt'.format(path)) as f:
            self.indices_char = json.loads(f.read())

    def _sample(self, preds, temperature=.6):
        """Sample an index from a probability array, with temperature.

        Lower temperatures sharpen the distribution (more deterministic);
        higher temperatures flatten it (more random).

        Args:
            preds: 1-D array of class probabilities (softmax output).
            temperature: Softmax temperature used to rescale `preds`.

        Returns:
            The sampled integer index.
        """
        preds = np.asarray(preds).astype('float64')
        # Tiny epsilon guards against log(0) -> -inf when a softmax
        # probability has underflowed to exactly zero.
        preds = np.log(preds + 1e-12) / temperature
        exp_preds = np.exp(preds)
        # Re-normalize so the values are a valid probability distribution
        # for np.random.multinomial.
        preds = exp_preds / np.sum(exp_preds)
        probas = np.random.multinomial(1, preds, 1)
        return np.argmax(probas)

    def _predict(self, args_dict):
        """Generate text based on seed text.

        Args:
            args_dict: Dict with keys:
                sentence: Input seed text to kick off generation.
                gen_chars: How many characters of text to generate.

        Returns:
            The generated text (string of length `gen_chars`).

        Raises:
            ValueError: If the seed contains a character that is not in the
                model's vocabulary.
        """
        # The model was trained on lowercase text only, and there is no
        # provision in the model itself for handling characters that are
        # out of vocabulary.
        # To compensate, turn everything into lowercase, then check for
        # out-of-vocab characters in the result.
        sentence, gen_chars = args_dict["sentence"], args_dict["gen_chars"]
        sentence = sentence.lower()
        for t, char in enumerate(sentence):
            if char not in self.char_indices:
                # Log via the module logger (not print) so the failure is
                # captured by the service's logging configuration.
                logger.warning("Bad char {} at position {}".format(char, t))
                raise ValueError(
                    "Unexpected character '{}' at position {}. "
                    "Only lowercase ASCII characters, spaces, "
                    "and basic punctuation are supported.".format(char, t))

        # The text passed into the model must be exactly SEED_TEXT_LEN
        # characters long, or the model will crash. Pad or truncate.
        if len(sentence) > SEED_TEXT_LEN:
            sentence = sentence[:SEED_TEXT_LEN]
        else:
            sentence = sentence.rjust(SEED_TEXT_LEN)

        generated = ''
        # Predictions must run inside the graph the model was loaded under.
        with self.graph.as_default():
            for i in range(gen_chars):
                # One-hot encode the current window of seed text:
                # shape (1, SEED_TEXT_LEN, vocab_size).
                x = np.zeros((1, SEED_TEXT_LEN, self.num_chars))
                for t, char in enumerate(sentence):
                    x[0, t, self.char_indices[char]] = 1.

                preds = self.model.predict(x, verbose=0)[0]
                next_index = self._sample(preds)
                next_char = self.indices_char[str(next_index)]

                generated += next_char
                # Slide the window: drop the oldest char, append the new one.
                sentence = sentence[1:] + next_char

        return generated