/* * Copyright 2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.cloud.stream.app.pose.estimation.processor; import java.awt.BasicStroke; import java.awt.Color; import java.awt.FontMetrics; import java.awt.Graphics2D; import java.awt.RenderingHints; import java.awt.Stroke; import java.awt.geom.AffineTransform; import java.awt.geom.Line2D; import java.awt.image.BufferedImage; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.List; import javax.imageio.ImageIO; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.cloud.stream.app.pose.estimation.model.Body; import org.springframework.cloud.stream.app.pose.estimation.model.Limb; import org.springframework.cloud.stream.app.pose.estimation.model.Model; import org.springframework.cloud.stream.app.pose.estimation.model.Part; import org.springframework.cloud.stream.app.tensorflow.processor.DefaultOutputMessageBuilder; import org.springframework.cloud.stream.app.tensorflow.processor.TensorflowCommonProcessorProperties; import org.springframework.cloud.stream.app.tensorflow.util.GraphicsUtils; import org.springframework.integration.support.MessageBuilder; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.util.MimeTypeUtils; /** * Extends the {@link DefaultOutputMessageBuilder} with ability to to augment the input image with the * recognized poses. * * @author Christian Tzolov */ public class PoseEstimateOutputMessageBuilder extends DefaultOutputMessageBuilder { private static final Log logger = LogFactory.getLog(PoseEstimateOutputMessageBuilder.class); public static final String IMAGE_FORMAT = "jpg"; public static final Color DEFAULT_COLOR = new Color(167, 252, 0); private PoseEstimationProcessorProperties poseProperties; public PoseEstimateOutputMessageBuilder(PoseEstimationProcessorProperties poseProperties, TensorflowCommonProcessorProperties properties) { super(properties); this.poseProperties = poseProperties; } @Override public MessageBuilder<?> createOutputMessageBuilder(Message<?> inputMessage, Object computedScore) { Message<?> annotatedInput = inputMessage; List<Body> bodies = (List<Body>) computedScore; if (this.poseProperties.isDrawPoses()) { try { byte[] annotatedImage = drawPoses((byte[]) inputMessage.getPayload(), bodies); annotatedInput = MessageBuilder.withPayload(annotatedImage) .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.APPLICATION_OCTET_STREAM_VALUE) .build(); } catch (IOException e) { logger.error("Failed to draw the poses", e); } } return super.createOutputMessageBuilder(annotatedInput, toJson(bodies)); } private byte[] drawPoses(byte[] imageBytes, List<Body> bodies) throws IOException { if (bodies != null) { BufferedImage originalImage = ImageIO.read(new ByteArrayInputStream(imageBytes)); Graphics2D g = originalImage.createGraphics(); g.setRenderingHint(RenderingHints.KEY_TEXT_ANTIALIASING, RenderingHints.VALUE_TEXT_ANTIALIAS_ON); g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON); Stroke stroke = g.getStroke(); g.setStroke(new BasicStroke(this.poseProperties.getDrawLineWidth())); for (Body body : bodies) { for (Limb limb : body.getLimbs()) { Color limbColor = findLimbColor(body, limb); Part from = limb.getFromPart(); Part to = limb.getToPart(); if (limb.getLimbType() != Model.LimbType.limb17 && limb.getLimbType() != Model.LimbType.limb18) { g.setColor(limbColor); g.draw(new Line2D.Double(from.getNormalizedX(), from.getNormalizedY(), to.getNormalizedX(), to.getNormalizedY())); } g.setStroke(new BasicStroke(1)); drawPartOval(from, this.poseProperties.getDrawPartRadius(), g); drawPartOval(to, this.poseProperties.getDrawPartRadius(), g); g.setStroke(new BasicStroke(this.poseProperties.getDrawLineWidth())); } } g.setStroke(stroke); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ImageIO.write(originalImage, IMAGE_FORMAT, baos); baos.flush(); imageBytes = baos.toByteArray(); baos.close(); g.dispose(); } return imageBytes; } private Color findLimbColor(Body body, Limb limb) { Color limbColor = DEFAULT_COLOR; ; switch (this.poseProperties.getBodyDrawingColorSchema()) { case bodyInstance: limbColor = GraphicsUtils.getClassColor(body.getBodyId() * 3); break; case limbType: limbColor = GraphicsUtils.LIMBS_COLORS[limb.getLimbType().getId()]; break; case monochrome: limbColor = DEFAULT_COLOR; break; } return limbColor; } private void drawPartOval(Part part, int radius, Graphics2D g) { int partX = part.getNormalizedX(); int partY = part.getNormalizedY(); g.setColor(GraphicsUtils.LIMBS_COLORS[part.getPartType().getId()]); g.fillOval(partX - radius, partY - radius, 2 * radius, 2 * radius); if (this.poseProperties.isDrawPartLabels()) { String label = part.getPartType().getId() + ":" + part.getPartType().name(); FontMetrics fm = g.getFontMetrics(); int labelX = partX + 5; int labelY = partY - 5; AffineTransform t = g.getTransform(); g.setTransform(AffineTransform.getRotateInstance(Math.toRadians(-35), labelX, labelY)); g.drawString(label, labelX, labelY); g.setTransform(t); } } private String toJson(List<Body> bodies) { try { return new ObjectMapper().writeValueAsString(bodies); } catch (JsonProcessingException e) { logger.error("Failed to encode the bodies into JSON message", e); } return "ERROR"; } }