#!/usr/bin/python # -*- coding:utf-8 -*- ''' @author: xichen ding @date: 2016-11-15 @rev: 2017-11-01 ''' from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals # compatible with python3 unicode coding import sys, os import tensorflow as tf import numpy as np import glob # adding pos submodule to sys.path, compatible with py3 absolute_import pkg_path = os.path.dirname(os.path.abspath(__file__)) # .../deepnlp/ sys.path.append(pkg_path) from pos import pos_model as pos_model from pos import reader as pos_reader from model_util import get_model_var_scope from model_util import get_config, load_config from model_util import _pos_scope_name class ModelLoader(object): def __init__(self, name, data_path, ckpt_path, conf_path): self.name = name # model name self.data_path = data_path self.ckpt_path = ckpt_path # the path of the ckpt file, e.g. ./ckpt/zh/pos.ckpt self.model_config_path = conf_path print("NOTICE: Starting new Tensorflow session...") self.session = tf.Session() print("NOTICE: Initializing pos_tagger class...") self.model = None self.var_scope = _pos_scope_name self._init_pos_model(self.session) # Initialization model def predict(self, words): ''' Coding: utf-8 for Chinese Characters Return tuples of [(word, tag),...] ''' tagging = self._predict_pos_tags(self.session, self.model, words, self.data_path) return tagging ## Initialize and Instance, Define Config Parameters for POS Tagger def _init_pos_model(self, session): """Create POS Tagger model and initialize with random or load parameters in session.""" # initilize config config_dict = load_config(self.model_config_path) config = get_config(config_dict, self.name) config.batch_size = 1 config.num_steps = 1 # iterator one token per time model_var_scope = get_model_var_scope(self.var_scope, self.name) print ("NOTICE: Input POS Model Var Scope Name '%s'" % model_var_scope) # Check if self.model already exist if self.model is None: with tf.variable_scope(model_var_scope, tf.AUTO_REUSE): self.model = pos_model.POSTagger(is_training=False, config=config) # save object after is_training # Load Specific .data* ckpt file if len(glob.glob(self.ckpt_path + '.data*')) > 0: # file exist with pattern: 'pos.ckpt.data*' print("NOTICE: Loading model parameters from %s" % self.ckpt_path) all_vars = tf.global_variables() model_vars = [k for k in all_vars if model_var_scope in k.name.split("/")] tf.train.Saver(model_vars).restore(session, self.ckpt_path) else: print("NOTICE: Model not found, Try to run method: deepnlp.download(module='pos', name='%s')" % self.name) print("NOTICE: Created with fresh parameters.") session.run(tf.global_variables_initializer()) def _predict_pos_tags(self, session, model, words, data_path): ''' Define prediction function of POS Tagging return tuples [(word, tag)] ''' word_data = pos_reader.sentence_to_word_ids(data_path, words) tag_data = [0]*len(word_data) state = session.run(model.initial_state) predict_id =[] for step, (x, y) in enumerate(pos_reader.iterator(word_data, tag_data, model.batch_size, model.num_steps)): #print ("Current Step" + str(step)) fetches = [model.cost, model.final_state, model.logits] feed_dict = {} feed_dict[model.input_data] = x feed_dict[model.targets] = y for i, (c, h) in enumerate(model.initial_state): feed_dict[c] = state[i].c feed_dict[h] = state[i].h _, _, logits = session.run(fetches, feed_dict) predict_id.append(int(np.argmax(logits))) #print (logits) predict_tag = pos_reader.word_ids_to_sentence(data_path, predict_id) return zip(words, predict_tag) def load_model(name = 'zh'): ''' data_path e.g.: ./deepnlp/pos/data/zh ckpt_path e.g.: ./deepnlp/pos/ckpt/zh/pos.ckpt ckpt_file e.g.: ./deepnlp/pos/ckpt/zh/pos.ckpt.data-00000-of-00001 ''' try: from deepnlp.model_util import registered_models except Exception as e: print (e) registered_model_list = registered_models[0]['pos'] if name not in registered_model_list: print ("WARNING: Input model name '%s' is not registered..." % name) print ("WARNING: Please use deepnlp.register_model('%s', '%s') ..." % ("pos", name)) return None data_path = os.path.join(pkg_path, "pos/data", name) # POS vocabulary data path ckpt_path = os.path.join(pkg_path, "pos/ckpt", name, "pos.ckpt") # POS model checkpoint path conf_path = os.path.join(pkg_path, "pos/data", "models.conf") return ModelLoader(name, data_path, ckpt_path, conf_path)