#
# Copyright 2018-2019 IBM Corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import os
import re

import numpy as np
import tensorflow as tf
from maxfw.model import MAXModelWrapper
from tensorflow.python.saved_model import tag_constants

from config import DEFAULT_MODEL_PATH, MODEL_META_DATA as model_meta
from core.utils import get_processing_word, load_vocab, pad_sequences

logger = logging.getLogger()


class ModelWrapper(MAXModelWrapper):
    """Model wrapper for TensorFlow models in SavedModel format.

    Loads the vocabulary assets and the SavedModel from a directory, and
    exposes ``predict`` to tag the terms of a raw text string.
    """

    MODEL_META_DATA = model_meta
    # The capturing group makes ``split`` return the non-word separators
    # (punctuation, whitespace) as items too, so punctuation is kept as terms.
    pat = re.compile(r'(\W+)')

    def __init__(self, path=DEFAULT_MODEL_PATH):
        """Load the vocabularies and the SavedModel found under ``path``.

        Args:
            path: Directory holding ``tags.txt``, ``chars.txt``, ``words.txt``
                and the SavedModel. Defaults to ``DEFAULT_MODEL_PATH``.
        """
        logger.info('Loading model from: %s...', path)

        # load assets first to enable model definition
        self._load_assets(path)

        # Loading the tf SavedModel. The model is read from ``path`` (the same
        # directory the assets came from) rather than unconditionally from
        # DEFAULT_MODEL_PATH, so a caller-supplied path is honoured.
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        tf.saved_model.loader.load(self.sess, [tag_constants.SERVING], path)

        graph = self.sess.graph
        self.word_ids_tensor = graph.get_tensor_by_name('word_input:0')
        self.char_ids_tensor = graph.get_tensor_by_name('char_input:0')
        self.output_tensor = graph.get_tensor_by_name('predict_output/truediv:0')

    def _load_assets(self, path):
        """Load the tag/char/word vocabularies and derive lookup helpers.

        Sets ``proc_fn`` (term -> ids), ``id_to_tag`` (inverse of the tag
        vocabulary), ``n_words``, ``n_char``, ``pad_tag`` and ``n_labels``.

        Args:
            path: Directory containing the three vocabulary text files.
        """
        vocab_tags = load_vocab(os.path.join(path, "tags.txt"))
        vocab_chars = load_vocab(os.path.join(path, "chars.txt"))
        vocab_words = load_vocab(os.path.join(path, "words.txt"))

        self.proc_fn = get_processing_word(
            vocab_words, vocab_chars, lowercase=True, chars=True)
        # Invert tag -> id so predictions can be mapped back to tag names.
        self.id_to_tag = {idx: tag for tag, idx in vocab_tags.items()}

        self.n_words = len(vocab_words)
        self.n_char = len(vocab_chars)
        n_tags = len(vocab_tags)
        # The padding id is the first id past the real tags, and n_labels
        # counts it in addition to the real tags.
        self.pad_tag = n_tags
        self.n_labels = n_tags + 1

    def _pre_process(self, x):
        """Split ``x`` into terms and encode them as id arrays.

        Args:
            x: Raw input text.

        Returns:
            A tuple ``(words_raw, word_ids_arr, char_ids_arr)``: the list of
            non-empty, whitespace-stripped terms (punctuation included) and
            the padded word-id and char-id arrays built from them. When ``x``
            contains no terms, the list is empty and both arrays are empty.
        """
        # strip whitespace and keep only non-empty terms, keeping raw punctuation
        stripped = (piece.strip() for piece in self.pat.split(x))
        words_raw = [term for term in stripped if term]

        if not words_raw:
            # Nothing to encode: unpacking ``zip(*[])`` below would raise
            # ValueError, so return empty arrays instead.
            return (words_raw,
                    np.empty((1, 0), dtype=int),
                    np.empty((1, 0, 0), dtype=int))

        # proc_fn yields a (char_ids, word_id) pair per term.
        encoded = [self.proc_fn(term) for term in words_raw]
        char_ids, word_ids = zip(*encoded)

        # The second value returned by pad_sequences is not needed here.
        # NOTE(review): both words and chars are padded with ``pad_tag``;
        # confirm this matches how the model was trained.
        word_ids, _ = pad_sequences([word_ids], pad_tok=self.pad_tag)
        char_ids, _ = pad_sequences([char_ids], pad_tok=self.pad_tag, nlevels=2)

        return words_raw, np.array(word_ids), np.array(char_ids)

    def _post_process(self, x):
        """Map an array of predicted tag ids to their tag names.

        Args:
            x: Array of tag ids; it is flattened before lookup.

        Returns:
            A flat list of tag names, one per id in ``x``.

        Raises:
            KeyError: If an id is not present in ``id_to_tag``.
        """
        return [self.id_to_tag[i] for i in x.ravel()]

    def _predict(self, word_ids_arr, char_ids_arr):
        """Run the model and return the arg-max index along the last axis.

        Args:
            word_ids_arr: Array fed to the ``word_input:0`` tensor.
            char_ids_arr: Array fed to the ``char_input:0`` tensor.

        Returns:
            Array of indices of the maximum output value along the last axis.
        """
        pred = self.sess.run(self.output_tensor, feed_dict={
            self.word_ids_tensor: word_ids_arr,
            self.char_ids_tensor: char_ids_arr
        })
        return np.argmax(pred, -1)

    def predict(self, x):
        """Tag every term of the input text.

        Args:
            x: Raw input text.

        Returns:
            A tuple ``(labels_pred, words)``: the predicted tag names and the
            terms they belong to. Both are empty lists when ``x`` contains no
            terms (e.g. an empty or whitespace-only string).
        """
        words, word_ids_arr, char_ids_arr = self._pre_process(x)
        if not words:
            # Skip the model call entirely for input without any terms.
            return [], []
        labels_pred_arr = self._predict(word_ids_arr, char_ids_arr)
        labels_pred = self._post_process(labels_pred_arr)
        return labels_pred, words