# coding:utf-8
# Restore model from pb file and do prediction
import sys
import codecs
import pickle
import numpy as np

from tensorflow.contrib import predictor
from pathlib import Path

PROJECT_PATH = Path(__file__).absolute().parent
sys.path.insert(0, str(PROJECT_PATH))

from utils.log import log_error as _error
from load_data import convert_to_idx, create_mask_for_lm, create_mask_for_seq, create_mask_for_bi

# Special tokens looked up in the vocabulary. The end token is the literal
# characters '<', '\', 's', '>' (the original source spelled it '<\s>').
MASK_TOKEN = '<mask>'
END_TOKEN = '<\\s>'

# Vocabulary lookup tables (token -> index and index -> token), stored as pickles.
# NOTE(review): these paths are relative to the current working directory, not
# PROJECT_PATH, so the script must be run from the project root — confirm intent.
# pickle.load can execute arbitrary code: only load these files from a trusted source.
with codecs.open('data/vocab_idx.pt', 'rb') as vocab_file, \
        codecs.open('data/idx_vocab.pt', 'rb') as idx_file:
    vocab_idx = pickle.load(vocab_file)
    idx_vocab = pickle.load(idx_file)


class bertPredict(object):
    """Run predictions with the newest SavedModel exported under a directory."""

    def __init__(self, pb_path):
        """Load the most recent exported SavedModel found directly under ``pb_path``.

        Args:
            pb_path: directory whose sub-directories are SavedModel exports.

        Raises:
            FileNotFoundError: if ``pb_path`` contains no usable export directory.
        """
        # Skip directories whose *name* contains 'temp' (presumably unfinished
        # exports). Matching on the name only, not the full path, keeps a 'temp'
        # substring in pb_path itself from filtering out every candidate.
        subdirs = [entry for entry in Path(pb_path).iterdir()
                   if entry.is_dir() and 'temp' not in entry.name]
        if not subdirs:
            raise FileNotFoundError(
                'No exported model directory found under {}'.format(pb_path))
        # The lexicographically last directory is taken as the newest one.
        # NOTE(review): assumes export directory names sort chronologically
        # (e.g. fixed-width timestamps) — confirm against the export code.
        latest = str(sorted(subdirs)[-1])
        self.predict_fn = predictor.from_saved_model(latest)
        self.vocab_idx, self.idx_vocab = vocab_idx, idx_vocab

    def predict(self, input_ids, max_length):
        """Pad the input, run the model and return the predictor's raw result.

        Args:
            input_ids: the input sentence, converted to token indices by
                ``convert_to_idx`` (the ``__main__`` demo passes a
                space-separated token string).
            max_length: total padded sequence length. The input must be strictly
                shorter so that at least one position is left to predict.

        Returns:
            The dict returned by the SavedModel predictor (the demo reads its
            ``'output'`` entry).

        Raises:
            ValueError: if the converted input is not shorter than ``max_length``.
        """
        token_ids = convert_to_idx(input_ids)
        batch_ids, batch_mask, batch_positions = self._process_input(token_ids, max_length)
        features = {
            'input_ids': np.array(batch_ids, dtype=np.int32),
            'input_mask': np.array(batch_mask, dtype=np.int32),
            'masked_lm_positions': np.array(batch_positions, dtype=np.int32),
        }
        return self.predict_fn(features)

    def _process_input(self, input_ids, max_length):
        """Build a batch of one padded example for the model.

        The input is padded to ``max_length`` with the ``<mask>`` token. Every
        padded position is marked as a masked-LM position to be predicted.

        Args:
            input_ids: sequence of token indices. It is not modified.
            max_length: total sequence length after padding.

        Returns:
            A tuple ``([padded_ids], [input_mask], [masked_lm_positions])``,
            each wrapped in a single-element list to form a batch of one.

        Raises:
            ValueError: if ``len(input_ids) >= max_length``.
        """
        question_length = len(input_ids)
        if question_length >= max_length:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            _error('Input length is larger than the maximum length')
            raise ValueError(
                'Input length ({}) must be smaller than max_length ({})'.format(
                    question_length, max_length))

        pad_length = max_length - question_length
        # Copy instead of `+=` so the caller's list is not mutated.
        padded_ids = list(input_ids) + [self.vocab_idx[MASK_TOKEN]] * pad_length

        # 1 for real tokens, 0 for padding; then expanded by the project helper.
        input_mask = [1] * question_length + [0] * pad_length
        input_mask = create_mask_for_seq(input_mask, question_length, pad_length)

        # Predict every padded position.
        masked_lm_positions = list(range(question_length, max_length))
        return [padded_ids], [input_mask], [masked_lm_positions]


if __name__ == '__main__':
    bert = bertPredict('models_to_deploy')
    test_sentence = '<s> 你 好 <\\s>'
    result = bert.predict(test_sentence, max_length=20)
    end_id = bert.vocab_idx[END_TOKEN]
    # Print predicted tokens until the end-of-sentence token appears.
    for idx in result['output']:
        if idx == end_id:
            break
        print(bert.idx_vocab[idx])