# coding=utf-8 # Copyright 2018 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for serving tensor2tensor.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import base64 from googleapiclient import discovery from grpc.beta import implementations from tensor2tensor import problems as problems_lib # pylint: disable=unused-import from tensor2tensor.data_generators import text_encoder from tensor2tensor.utils import cloud_tpu as cloud import tensorflow as tf from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import prediction_service_pb2 def _make_example(input_ids, feature_name="inputs"): features = { feature_name: tf.train.Feature(int64_list=tf.train.Int64List(value=input_ids)) } return tf.train.Example(features=tf.train.Features(feature=features)) def _create_stub(server): host, port = server.split(":") channel = implementations.insecure_channel(host, int(port)) # TODO(bgb): Migrate to GA API. return prediction_service_pb2.beta_create_PredictionService_stub(channel) def _encode(inputs, encoder, add_eos=True): input_ids = encoder.encode(inputs) if add_eos: input_ids.append(text_encoder.EOS_ID) return input_ids def _decode(output_ids, output_decoder): return output_decoder.decode(output_ids, strip_extraneous=True) def make_grpc_request_fn(servable_name, server, timeout_secs): """Wraps function to make grpc requests with runtime args.""" stub = _create_stub(server) def _make_grpc_request(examples): """Builds and sends request to TensorFlow model server.""" request = predict_pb2.PredictRequest() request.model_spec.name = servable_name request.inputs["input"].CopyFrom( tf.contrib.util.make_tensor_proto( [ex.SerializeToString() for ex in examples], shape=[len(examples)])) response = stub.Predict(request, timeout_secs) outputs = tf.make_ndarray(response.outputs["outputs"]) scores = tf.make_ndarray(response.outputs["scores"]) assert len(outputs) == len(scores) return [{ "outputs": outputs[i], "scores": scores[i] } for i in range(len(outputs))] return _make_grpc_request def make_cloud_mlengine_request_fn(credentials, model_name, version): """Wraps function to make CloudML Engine requests with runtime args.""" def _make_cloud_mlengine_request(examples): """Builds and sends requests to Cloud ML Engine.""" api = discovery.build("ml", "v1", credentials=credentials) parent = "projects/%s/models/%s/versions/%s" % (cloud.default_project(), model_name, version) input_data = { "instances": [{ "input": { "b64": base64.b64encode(ex.SerializeToString()) } } for ex in examples] } prediction = api.projects().predict(body=input_data, name=parent).execute() return prediction["predictions"] return _make_cloud_mlengine_request def predict(inputs_list, problem, request_fn): """Encodes inputs, makes request to deployed TF model, and decodes outputs.""" assert isinstance(inputs_list, list) fname = "inputs" if problem.has_inputs else "targets" input_encoder = problem.feature_info[fname].encoder input_ids_list = [ _encode(inputs, input_encoder, add_eos=problem.has_inputs) for inputs in inputs_list ] examples = [_make_example(input_ids, fname) for input_ids in input_ids_list] predictions = request_fn(examples) output_decoder = problem.feature_info["targets"].encoder outputs = [ (_decode(prediction["outputs"], output_decoder), prediction["scores"]) for prediction in predictions ] return outputs