package com.github.megachucky.kafka.streams.machinelearning.test; import static org.assertj.core.api.Assertions.assertThat; import java.io.IOException; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; import java.util.List; import java.util.Properties; import com.github.jukkakarvanen.kafka.streams.integration.utils.TestEmbeddedKafkaCluster; import com.github.jukkakarvanen.kafka.streams.integration.utils.TestKafkaStreams; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.kafka.common.serialization.StringSerializer; import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.streams.KafkaStreams; import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.StreamsBuilder; import org.apache.kafka.streams.StreamsConfig; import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; import org.apache.kafka.streams.kstream.KStream; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; import org.tensorflow.Tensor; import com.github.megachucky.kafka.streams.machinelearning.Kafka_Streams_TensorFlow_Image_Recognition_Example; /** * * @author Kai Waehner (www.kai-waehner.de) * * End-to-end integration test based on * {@link Kafka_Streams_TensorFlow_Image_Recognition_Example}, using an * embedded Kafka cluster and a TensorFlow CNN model. * * * */ public class Kafka_Streams_TensorFlow_Image_Recognition_Example_IntegrationTest { @ClassRule public static final EmbeddedKafkaCluster CLUSTER = new TestEmbeddedKafkaCluster(1); private static final String inputTopic = "ImageInputTopic"; private static final String outputTopic = "ImageOutputTopic"; // Prediction Value private static String imageClassification = "unknown"; @BeforeClass public static void startKafkaCluster() throws Exception { CLUSTER.createTopic(inputTopic); CLUSTER.createTopic(outputTopic); } @Test public void shouldRecognizeImages() throws Exception { // Images: 'unknown', Airliner, 'unknown', Butterfly List<String> inputValues = Arrays.asList("src/main/resources/TensorFlow_Images/trained_airplane_2.jpg", "src/main/resources/TensorFlow_Images/devil.png", "src/main/resources/TensorFlow_Images/trained_butterfly.jpg"); // ######################################################## // Step 1: Configure and start the processor topology. // ######################################################## Properties streamsConfiguration = new Properties(); streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "kafka-streams-tensorflow-image-recognition-integration-test"); streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); // Create TensorFlow object String modelDir = "src/main/resources/generatedModels/CNN_inception5h"; Path pathGraph = Paths.get(modelDir, "tensorflow_inception_graph.pb"); byte[] graphDef = Files.readAllBytes(pathGraph); Path pathModel = Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"); List<String> labels = Files.readAllLines(pathModel, Charset.forName("UTF-8")); // Configure Kafka Streams Application // Specify default (de)serializers for record keys and for record // values. streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); // In the subsequent lines we define the processing topology of the // Streams application. final StreamsBuilder builder = new StreamsBuilder(); // Construct a `KStream` from the input topic "AirlineInputTopic", where // message values // represent lines of text (for the sake of this example, we ignore // whatever may be stored // in the message keys). final KStream<String, String> imageInputLines = builder.stream(inputTopic); // Stream Processor (in this case 'foreach' to add custom logic, i.e. // apply the analytic model) imageInputLines.foreach((key, value) -> { imageClassification = "unknown"; String imageFile = value; Path pathImage = Paths.get(imageFile); byte[] imageBytes; try { imageBytes = Files.readAllBytes(pathImage); // Load and execute TensorFlow graph try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) { float[] labelProbabilities = executeInceptionGraph(graphDef, image); int bestLabelIdx = maxIndex(labelProbabilities); imageClassification = labels.get(bestLabelIdx); System.out.println(String.format("BEST MATCH: %s (%.2f%% likely)", imageClassification, labelProbabilities[bestLabelIdx] * 100f)); } } catch (IOException e) { e.printStackTrace(); } }); // Transform message: Add prediction information KStream<String, Object> transformedMessage = imageInputLines .mapValues(value -> "Image Recognition: What is content of the picture? => " + imageClassification); // Send prediction information to Output Topic transformedMessage.to(outputTopic); // Start Kafka Streams Application to process new incoming messages from // Input Topic final KafkaStreams streams = new TestKafkaStreams(builder.build(), streamsConfiguration); streams.cleanUp(); streams.start(); System.out.println("Image Recognition Microservice is running..."); System.out.println("Input to Kafka Topic " + inputTopic + "; Output to Kafka Topic " + outputTopic); // ######################################################## // Step 2: Produce some input data to the input topic. // ######################################################## Properties producerConfig = new Properties(); producerConfig.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); producerConfig.put(ProducerConfig.ACKS_CONFIG, "all"); producerConfig.put(ProducerConfig.RETRIES_CONFIG, 0); producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); producerConfig.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); IntegrationTestUtils.produceValuesSynchronously(inputTopic, inputValues, producerConfig, new MockTime()); // ######################################################## // Step 3: Verify the application's output data. // ######################################################## Properties consumerConfig = new Properties(); consumerConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); consumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-streams-tensorflow-image-recognition-integration-test-standard-consumer"); consumerConfig.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); consumerConfig.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); consumerConfig.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); List<KeyValue<String, String>> response = IntegrationTestUtils .waitUntilMinKeyValueRecordsReceived(consumerConfig, outputTopic, 3); streams.close(); assertThat(response).isNotNull(); assertThat(response.get(0).value).isEqualTo("Image Recognition: What is content of the picture? => airliner"); assertThat(response.get(1).value) .isNotEqualTo("Image Recognition: What is content of the picture? => airliner"); assertThat(response.get(2).value) .isEqualTo("Image Recognition: What is content of the picture? => cabbage butterfly"); } // ######################################################################################## // Private helper class for construction and execution of the pre-built // TensorFlow model // ######################################################################################## private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) { try (Graph g = new Graph()) { GraphBuilder b = new GraphBuilder(g); // Some constants specific to the pre-trained model at: // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip // // - The model was trained with images scaled to 224x224 pixels. // - The colors, represented as R, G, B in 1-byte each were // converted to // float using (value - Mean)/Scale. final int H = 224; final int W = 224; final float mean = 117f; final float scale = 1f; // Since the graph is being constructed once per execution here, we // can use a constant for the // input image. If the graph were to be re-used for multiple input // images, a placeholder would // have been more appropriate. final Output input = b.constant("input", imageBytes); final Output output = b .div(b.sub( b.resizeBilinear(b.expandDims(b.cast(b.decodeJpeg(input, 3), DataType.FLOAT), b.constant("make_batch", 0)), b.constant("size", new int[] { H, W })), b.constant("mean", mean)), b.constant("scale", scale)); try (Session s = new Session(g)) { return s.runner().fetch(output.op().name()).run().get(0); } } } private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) { try (Graph g = new Graph()) { g.importGraphDef(graphDef); try (Session s = new Session(g); Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) { final long[] rshape = result.shape(); if (result.numDimensions() != 2 || rshape[0] != 1) { throw new RuntimeException(String.format( "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(rshape))); } int nlabels = (int) rshape[1]; return result.copyTo(new float[1][nlabels])[0]; } } } private static int maxIndex(float[] probabilities) { int best = 0; for (int i = 1; i < probabilities.length; ++i) { if (probabilities[i] > probabilities[best]) { best = i; } } return best; } // In the fullness of time, equivalents of the methods of this class should // be auto-generated from // the OpDefs linked into libtensorflow_jni.so. That would match what is // done in other languages // like Python, C++ and Go. static class GraphBuilder { GraphBuilder(Graph g) { this.g = g; } Output div(Output x, Output y) { return binaryOp("Div", x, y); } Output sub(Output x, Output y) { return binaryOp("Sub", x, y); } Output resizeBilinear(Output images, Output size) { return binaryOp("ResizeBilinear", images, size); } Output expandDims(Output input, Output dim) { return binaryOp("ExpandDims", input, dim); } Output cast(Output value, DataType dtype) { return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0); } Output decodeJpeg(Output contents, long channels) { return g.opBuilder("DecodeJpeg", "DecodeJpeg").addInput(contents).setAttr("channels", channels).build() .output(0); } Output constant(String name, Object value) { try (Tensor t = Tensor.create(value)) { return g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0); } } private Output binaryOp(String type, Output in1, Output in2) { return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0); } private Graph g; } }