#
# Copyright 2018-2019 IBM Corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from maxfw.model import MAXModelWrapper

import math
import logging

import tensorflow as tf

from core import configuration
from core import inference_wrapper
from core.inference_utils import vocabulary
from core.inference_utils import caption_generator

from config import DEFAULT_MODEL_PATH, VOCAB_FILE

logger = logging.getLogger()


class ModelWrapper(MAXModelWrapper):

    def __init__(self, path=DEFAULT_MODEL_PATH):
        # TODO Replace this part with SavedModel
        g = tf.Graph()
        with g.as_default():
            model = inference_wrapper.InferenceWrapper()
            restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
                                                       path)
        g.finalize()
        self.model = model
        sess = tf.Session(graph=g)
        # Load the model from checkpoint.
        restore_fn(sess)
        self.sess = sess

    def _predict(self, image_data):
        # Create the vocabulary.
        vocab = vocabulary.Vocabulary(VOCAB_FILE)

        # Prepare the caption generator. Here we are implicitly using the default
        # beam search parameters. See caption_generator.py for a description of the
        # available beam search parameters.
        generator = caption_generator.CaptionGenerator(self.model, vocab)

        captions = generator.beam_search(self.sess, image_data)

        results = []
        for i, caption in enumerate(captions):
            # Ignore begin and end words.
            sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
            sentence = " ".join(sentence)
            # print("  %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))
            results.append((i, sentence, math.exp(caption.logprob)))

        return results