/*
 * Copyright 2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.stream.app.tensorflow.processor;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.cloud.stream.messaging.Processor;
import org.springframework.context.annotation.Bean;
import org.springframework.expression.EvaluationContext;
import org.springframework.integration.annotation.ServiceActivator;
import org.springframework.integration.context.IntegrationContextUtils;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.messaging.converter.MessageConversionException;
import org.springframework.tuple.Tuple;
import org.springframework.tuple.TupleBuilder;
import org.tensorflow.Tensor;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A processor that evaluates a machine learning model stored in TensorFlow's ProtoBuf format.
 * <p>
 * The processor uses a {@link TensorflowInputConverter} to convert the input data into the TensorFlow model input
 * format (called feeds). The input converter converts the input {@link Message} into a key/value {@link Map}, where
 * each key corresponds to a model input placeholder (feed) and each value is a {@link org.tensorflow.DataType}
 * compliant value. The default converter implementation expects a {@link Map} payload.
 * The {@link TensorflowInputConverter} can be extended and customized.
 * <p>
 * The processor's output uses the {@link TensorflowOutputConverter} to convert the computed {@link Tensor} result
 * into a serializable message. The default implementation converts the Tensor result into a {@link Tuple} triple
 * (see {@link TensorflowOutputConverter}).
 * The {@link TensorflowOutputConverter} can be extended and customized to provide a convenient data representation
 * tailored to a particular model (see TwitterSentimentTensorflowOutputConverter.java).
 * <p>
 * How the inference result is delivered is controlled by the {@code mode} property. By default the inference result
 * is returned in the outbound Message payload. In {@code header} mode the inference result is stored in the outbound
 * Message header named by the {@code outputName} property, and the message payload is the same as the inbound
 * message payload. In {@code tuple} mode the result is returned in a {@link Tuple} under the {@code outputName} key,
 * together with the original input data stored under {@code original.input.data}; if the inbound payload is already
 * a {@link Tuple} containing that entry, its content is copied into the outbound tuple instead.
 *
 * @author Christian Tzolov
 * @author Artem Bilan
 */
public class TensorflowCommonProcessorConfiguration implements AutoCloseable {

	public static final String ORIGINAL_INPUT_DATA = "original.input.data";

	private static final Log logger = LogFactory.getLog(TensorflowCommonProcessorConfiguration.class);

	private EvaluationContext evaluationContext;

	private TensorflowCommonProcessorProperties properties;

	private TensorflowInputConverter tensorflowInputConverter;

	private TensorflowOutputConverter tensorflowOutputConverter;

	private TensorFlowService tensorFlowService;

	@ServiceActivator(inputChannel = Processor.INPUT, outputChannel = Processor.OUTPUT)
	public Object evaluate(Message<?> input) {

		Object inputData =
				this.properties.getExpression() == null
						? input.getPayload()
						: this.properties.getExpression().getValue(this.evaluationContext, input, Object.class);

		// The processorContext allows to convey metadata from the Input to Output converter.
		Map<String, Object> processorContext = new ConcurrentHashMap<>();

		Map<String, Object> inputDataMap = this.tensorflowInputConverter.convert(inputData, processorContext);

		Tensor outputTensor = this.tensorFlowService.evaluate(inputDataMap, this.properties.getModelFetch(),

		Object outputData = tensorflowOutputConverter.convert(outputTensor, processorContext);

		switch (this.properties.getMode()) {

			case tuple:
				TupleBuilder outTupleBuilder = TupleBuilder.tuple().put(properties.getOutputName(), outputData);

				Object payload = input.getPayload();

				if (payload instanceof Tuple && ((Tuple) payload).hasFieldName(ORIGINAL_INPUT_DATA)) {
					// If the payload is already a tuple that contains ORIGINAL_INPUT_DATA entry then copy the
					// content of the input tuple in the new tuple to be returned.
					outTupleBuilder.putAll((Tuple) payload);
				else {
					// This is a new tuple so preserve the input data.
					outTupleBuilder.put(ORIGINAL_INPUT_DATA, payload);

				return outTupleBuilder.build();

			case header:
				return MessageBuilder
						.setHeader(this.properties.getOutputName(), outputData);

				return outputData;


	public TensorFlowService tensorFlowService() throws IOException {
		return new TensorFlowService(this.properties.getModel());

	@ConditionalOnMissingBean(name = "tensorflowOutputConverter")
	public TensorflowOutputConverter tensorflowOutputConverter() {
		// Default implementations serializes the Tensor into Tuple
		return new TensorflowOutputConverter<Tuple>() {

			public Tuple convert(Tensor tensor, Map<String, Object> processorContext) {
				return TensorTupleConverter.toTuple(tensor);

	@ConditionalOnMissingBean(name = "tensorflowInputConverter")
	public TensorflowInputConverter tensorflowInputConverter() {
		return new TensorflowInputConverter() {

			public Map<String, Object> convert(Object input, Map<String, Object> processorContext) {

				if (input instanceof Map) {
					return (Map<String, Object>) input;

				throw new MessageConversionException("Unsupported input format: " + input);

	public void close() throws Exception {
		logger.info("Close TensorflowCommonProcessorConfiguration");