/*
 * Copyright 2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.stream.app.pose.estimation.processor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tensorflow.Tensor;

import org.springframework.cloud.stream.app.pose.estimation.model.Body;
import org.springframework.cloud.stream.app.pose.estimation.model.Limb;
import org.springframework.cloud.stream.app.pose.estimation.model.Model;
import org.springframework.cloud.stream.app.pose.estimation.model.Part;
import org.springframework.cloud.stream.app.tensorflow.processor.TensorflowOutputConverter;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
 * Converts the output tensor of an OpenPose-style pose estimation model into a list of {@link Body} instances.
 *
 * A single model output is fetched. It is expected to have shape {@code [1][height][width][57]}, where the last
 * dimension holds both the part confidence heatmaps and the part affinity field (PAF) layers. The conversion
 * runs in three phases:
 * <ol>
 *   <li>select high-confidence {@link Part} candidates from the heatmaps (non-maximum suppression),</li>
 *   <li>connect the part candidates into {@link Limb}s, scored by a line integral over the PAF,</li>
 *   <li>assemble the connected limbs into {@link Body} instances.</li>
 * </ol>
 *
 * Credits:
 *  - https://arvrjourney.com/human-pose-estimation-using-openpose-with-tensorflow-part-2-e78ab9104fc8
 *
 * @author Christian Tzolov
 */
public class PoseEstimationTensorflowOutputConverter implements TensorflowOutputConverter<List<Body>> {

	private static final Log logger = LogFactory.getLog(PoseEstimationTensorflowOutputConverter.class);

	/** Expected size of the last output tensor dimension: HeatMapCount + PafMapCount. */
	private static final int HEATMAP_PAFMAP_LAYER_COUNT = 57;

	/** Number of points sampled along a candidate limb when computing its PAF line integral. */
	private static final int PAF_STEP_COUNT = 10;

	private final String modelFetchOutput;

	private final PoseEstimationProcessorProperties poseProperties;

	/**
	 * @param poseEstimationProcessorProperties tuning properties (NMS, PAF thresholds, debug visualization).
	 * @param modelFetch names of the model outputs to fetch. Exactly one output is supported.
	 */
	public PoseEstimationTensorflowOutputConverter(PoseEstimationProcessorProperties poseEstimationProcessorProperties,
			List<String> modelFetch) {
		Assert.isTrue(modelFetch.size() == 1, "A single model output is supported");
		this.modelFetchOutput = modelFetch.get(0);
		logger.info("Pose Estimation model fetch output: " + this.modelFetchOutput);
		this.poseProperties = poseEstimationProcessorProperties;
		logger.info("Pose Estimation properties: " + this.poseProperties);
	}

	/**
	 * Converts the fetched model output into the detected body postures.
	 *
	 * The fetched output tensor is closed by this method once it has been copied.
	 *
	 * @param tensorMap model outputs keyed by fetch name. Must contain the configured fetch output.
	 * @param processorContext processing context. When debug visualization is enabled, its {@code "inputImage"}
	 *                         entry is read as the raw input image bytes.
	 * @return the assembled bodies.
	 */
	@Override
	public List<Body> convert(Map<String, Tensor<?>> tensorMap, Map<String, Object> processorContext) {

		Tensor<?> fetchedTensor = tensorMap.get(this.modelFetchOutput);
		Assert.notNull(fetchedTensor, "Missing model output tensor: " + this.modelFetchOutput);

		try (Tensor<Float> openPoseOutputTensor = fetchedTensor.expect(Float.class)) {

			long[] shape = openPoseOutputTensor.shape();
			Assert.isTrue(shape.length == 4, "Expected a rank 4 output tensor but got rank: " + shape.length);

			int height = (int) shape[1]; // = input image's height / 8;
			int width = (int) shape[2]; //  = input image's width / 8;
			int heatmapPafmapCount = (int) shape[3]; // HeatMapCount + PafMapCount = 57 layers

			Assert.isTrue(heatmapPafmapCount == HEATMAP_PAFMAP_LAYER_COUNT, "Incorrect number of output tensor layer");

			// [H] [W] [Heat + PAF]
			float[][][] tensorData = openPoseOutputTensor.copyTo(new float[1][height][width][heatmapPafmapCount])[0];

			boolean debugVisualization = poseProperties.isDebugVisualisationEnabled();
			byte[] inputImage = debugVisualization ? (byte[]) processorContext.get("inputImage") : null;

			if (debugVisualization) {
				DebugVisualizationUtility.visualizeAllPafHeatMapChannels(inputImage, tensorData,
						poseProperties.getDebugVisualizationOutputPath() + "/PosePartHeatMap.jpg");
				DebugVisualizationUtility.visualizeAllPafChannels(inputImage, tensorData,
						poseProperties.getDebugVisualizationOutputPath() + "/PosePafField.jpg");
			}

			// -------------------------------------------------------------------------------------------------------
			// 1. Select the Part instances with higher confidence
			// -------------------------------------------------------------------------------------------------------

			// Perform non-maximum suppression on the detection confidence (e.g. heatmap) maps to obtain a discrete
			// set of part candidate locations. For each part, several candidates could appear, due to multiple
			// people in the image or false positives.
			Map<Model.PartType, List<Part>> parts = new HashMap<>();
			for (Model.PartType partType : Model.Body.getPartTypes()) {
				parts.put(partType, findHighConfidenceParts(partType, height, width, tensorData));
			}

			if (debugVisualization) {
				DebugVisualizationUtility.visualizePartCandidates(inputImage,
						parts, poseProperties.getDebugVisualizationOutputPath() + "/PosePartCandidates.jpg");
			}

			// -------------------------------------------------------------------------------------------------------
			// 2. Connect the selected Parts into Limbs
			// -------------------------------------------------------------------------------------------------------

			// Part candidates define a large set of possible limbs. Score each candidate Limb using the line integral
			// computation on the PAF.
			Map<Model.LimbType, List<Limb>> limbs = new HashMap<>();

			// For every Limb Type, retrieve the "from" and "to" Part Types. Retrieve all Part instances for
			// those types and find the relationships (e.g. limbs) between them.
			for (Model.LimbType limbType : Model.Body.getLimbTypes()) {

				List<Part> fromParts = parts.get(limbType.getFromPartType());
				List<Part> toParts = parts.get(limbType.getToPartType());

				// A limb can only exist when both of its end-part types have candidates.
				if (!CollectionUtils.isEmpty(fromParts) && !CollectionUtils.isEmpty(toParts)) {
					// Candidates are sorted by PAF score.
					PriorityQueue<Limb> limbCandidatesQueue = findLimbCandidates(limbType, fromParts, toParts, tensorData);

					// Determine the final Limb instances
					limbs.put(limbType, selectFinalLimbs(limbType, limbCandidatesQueue));
				}
			}

			if (debugVisualization) {
				DebugVisualizationUtility.visualizeLimbCandidates(inputImage,
						limbs, poseProperties.getDebugVisualizationOutputPath() + "/LimbCandidates.jpg");
			}

			// ---------------------------------------------------------------
			// 3. Assemble the selected Limbs and Parts into Bodies
			// ---------------------------------------------------------------
			return assembleBodies(limbs);
		}
	}

	/**
	 * For a {@link org.springframework.cloud.stream.app.pose.estimation.model.Model.PartType}, identifies the
	 * body part locations that have higher confidence to belong to the requested type.
	 *
	 * The Non-maximum Suppression (NMS) algorithm is used to extract the part locations out of a heatmap and
	 * to suppress overlapping parts. A cell becomes a part candidate when its score meets two conditions:
	 * it is the maximum within the NMS window centered on it, and it exceeds the configured NMS threshold.
	 *
	 * @param partType Part Type for which part candidates will be searched.
	 * @param height Heatmap height (1/8 of the input image height).
	 * @param width Heatmap width (1/8 of the input image width).
	 * @param outputTensor {@code [height][width][layers]} output data. The layer {@code partType.getId()} is
	 *                     used as the confidence map (e.g. heatmap) of this part type, i.e. the belief that
	 *                     the part occurs at each pixel location.
	 * @return a list of part candidates for the given Part Type, in row-major scan order.
	 */
	private List<Part> findHighConfidenceParts(Model.PartType partType, int height, int width,
			float[][][] outputTensor) {

		// The NMS window covers the offsets [minNmsRadius, maxNmsRadius) around the center cell. For an odd
		// window size w >= 1 this is [-(w-1)/2, (w-1)/2]. An even w behaves like w-1.
		final int minNmsRadius = -(poseProperties.getNmsWindowSize() - 1) / 2;
		final int maxNmsRadius = (poseProperties.getNmsWindowSize() + 1) / 2;

		final int channel = partType.getId();

		List<Part> partsPerType = new ArrayList<>();

		// Visit only the centers whose whole window fits inside the map. The largest offset used is
		// maxNmsRadius - 1, so the center must satisfy y + maxNmsRadius - 1 <= height - 1 (and likewise for x).
		for (int y = Math.abs(minNmsRadius); y <= height - maxNmsRadius; y++) {
			for (int x = Math.abs(minNmsRadius); x <= width - maxNmsRadius; x++) {
				float maxPartScore = 0;
				for (int stepY = minNmsRadius; stepY < maxNmsRadius; stepY++) {
					for (int stepX = minNmsRadius; stepX < maxNmsRadius; stepX++) {
						maxPartScore = Math.max(maxPartScore, outputTensor[y + stepY][x + stepX][channel]);
					}
				}
				// Keep the cell only if it is confident enough and is the local maximum of its window.
				if (maxPartScore > poseProperties.getNmsThreshold()
						&& maxPartScore == outputTensor[y][x][channel]) {
					partsPerType.add(new Part(partType, partsPerType.size(), y, x, maxPartScore));
				}
			}
		}

		return partsPerType;
	}

	/**
	 * Scores every {from -> to} part pair of a Limb Type against the limb's Part Affinity Field (PAF).
	 *
	 * The PAF is a 2D vector field for each limb. For each pixel in the area belonging to a particular limb,
	 * a 2D vector encodes the direction that points from one part of the limb to the other. Each type of limb
	 * has a corresponding affinity field joining its two associated body parts.
	 *
	 * A pair becomes a candidate only when it meets two conditions: its total PAF score exceeds the total-score
	 * threshold, and enough of its sampled step scores exceed the step-score threshold.
	 *
	 * @param limbType Limb Type to find limb candidates for.
	 * @param fromParts part candidates of the limb's "from" Part Type.
	 * @param toParts part candidates of the limb's "to" Part Type.
	 * @param outputTensor {@code [height][width][layers]} output data holding the PAF layers.
	 * @return a queue of Limb candidates ordered by their total PAF score in descending order.
	 */
	private PriorityQueue<Limb> findLimbCandidates(Model.LimbType limbType, List<Part> fromParts, List<Part> toParts,
			float[][][] outputTensor) {

		// Use a priority queue to keep the limb instance candidates in descending score order.
		int initialSize = (fromParts.size() * toParts.size()) / 2 + 1;
		PriorityQueue<Limb> limbCandidatesQueue = new PriorityQueue<>(initialSize,
				(limb1, limb2) -> Double.compare(limb2.getPafScore(), limb1.getPafScore()));

		// For every {from -> to} pair, compute a line integral over the Limb-PAF vector field along the line
		// connecting both Parts. The computed value is used as the Limb candidate score. The higher the value,
		// the higher the chance of a connection between those Parts.
		for (Part fromPart : fromParts) {
			for (Part toPart : toParts) {

				// getY() indexes the first (row) dimension of outputTensor and getX() the second (column) one.
				float deltaRow = toPart.getY() - fromPart.getY();
				float deltaCol = toPart.getX() - fromPart.getX();
				float norm = (float) Math.sqrt(deltaRow * deltaRow + deltaCol * deltaCol);

				// Skip self-pointing edges (e.g. fromPartInstance == toPartInstance)
				if (norm <= 1e-12) {
					continue;
				}

				// Unit direction from the "from" part to the "to" part.
				float unitRow = deltaRow / norm;
				float unitCol = deltaCol / norm;

				int stepPafScoreCount = 0;
				float totalPafScore = 0.0f;
				for (int t = 0; t < PAF_STEP_COUNT; t++) {
					// Sample point on the segment. "+ 0.5" rounds to the nearest cell for non-negative coordinates.
					int row = (int) ((float) fromPart.getY() + (t * deltaRow / PAF_STEP_COUNT) + 0.5);
					int col = (int) ((float) fromPart.getX() + (t * deltaCol / PAF_STEP_COUNT) + 0.5);

					float pafScoreX = outputTensor[row][col][limbType.getPafIndexX()];
					float pafScoreY = outputTensor[row][col][limbType.getPafIndexY()];

					// Dot product of the sampled PAF vector (x, y) with the limb direction (column, row).
					float stepPafScore = (unitCol * pafScoreX) + (unitRow * pafScoreY);

					totalPafScore += stepPafScore;

					// Count only the step PAF scores above the pre-defined stepPafScoreThreshold.
					if (stepPafScore > poseProperties.getStepPafScoreThreshold()) {
						stepPafScoreCount++;
					}
				}

				if (totalPafScore > poseProperties.getTotalPafScoreThreshold()
						&& stepPafScoreCount >= poseProperties.getPafCountThreshold()) {
					limbCandidatesQueue.add(new Limb(limbType, totalPafScore, fromPart, toPart));
				}
			}
		}

		return limbCandidatesQueue;
	}

	/**
	 * From all possible limb candidates for a given Limb Type, selects those that maximize the total PAF score.
	 * The algorithm starts from the limb candidates with the highest PAF score. It tracks the parts already
	 * assigned to final limbs and rejects any limb candidate that uses an already assigned part.
	 *
	 * The queue is drained by this method.
	 *
	 * @param limbType Limb Type for which final limbs are selected.
	 * @param limbCandidatesQueue possible Limb candidates, sorted by total PAF score in descending order.
	 * @return the final list of Limbs for the given
	 * {@link org.springframework.cloud.stream.app.pose.estimation.model.Model.LimbType}.
	 */
	private List<Limb> selectFinalLimbs(Model.LimbType limbType, PriorityQueue<Limb> limbCandidatesQueue) {

		List<Limb> finalLimbs = new ArrayList<>();

		// Parts assigned to final limbs.
		Set<Part> assignedParts = new HashSet<>();

		// Start from the candidates with the highest PAF score and progress in descending order.
		while (!limbCandidatesQueue.isEmpty()) {

			Limb limbCandidate = limbCandidatesQueue.poll();

			Assert.isTrue(limbType == limbCandidate.getLimbType(), "Incorrect Limb Type!");

			// Ignore candidate limbs whose parts were already assigned to a final Limb in an earlier iteration.
			if (!assignedParts.contains(limbCandidate.getFromPart())
					&& !assignedParts.contains(limbCandidate.getToPart())) {

				// Make the candidate final.
				finalLimbs.add(limbCandidate);

				// Mark the limb's parts as assigned.
				assignedParts.add(limbCandidate.getFromPart());
				assignedParts.add(limbCandidate.getToPart());
			}
		}

		return finalLimbs;
	}

	/**
	 * Grows the bodies out of their limbs. Limbs that share a Part end up in the same Body. When a limb connects
	 * two different bodies, the "to" body's limbs and parts are merged into the "from" body.
	 *
	 * NOTE(review): this relies on {@code Body.addLimb} also registering the limb's parts in
	 * {@code Body.getParts()} (the part-count filter below depends on it). TODO confirm in Body.
	 *
	 * @param limbsMap final limbs to connect into bodies, grouped by Limb Type.
	 * @return the final list of body postures. Only bodies with strictly more than {@code minBodyPartCount}
	 * parts are kept.
	 */
	private List<Body> assembleBodies(Map<Model.LimbType, List<Limb>> limbsMap) {

		AtomicInteger bodyId = new AtomicInteger();

		// Maps every already assembled Part to the Body it currently belongs to.
		Map<Part, Body> partToBodyIndex = new ConcurrentHashMap<>();

		for (List<Limb> limbsPerType : limbsMap.values()) {
			for (Limb limb : limbsPerType) {

				Body fromBody = partToBodyIndex.get(limb.getFromPart());
				Body toBody = partToBodyIndex.get(limb.getToPart());

				Body bodyCandidate;

				if (fromBody == null && toBody == null) {
					// Neither part is known yet: start a new body.
					bodyCandidate = new Body(bodyId.getAndIncrement());
				}
				else if (fromBody != null && toBody != null) {
					bodyCandidate = fromBody;
					if (!fromBody.equals(toBody)) {
						// The limb bridges two bodies: merge the "to" body into the "from" body.
						bodyCandidate.getLimbs().addAll(toBody.getLimbs());
						bodyCandidate.getParts().addAll(toBody.getParts());
						toBody.getParts().forEach(p -> partToBodyIndex.put(p, bodyCandidate));
					}
				}
				else {
					// Exactly one part is known: extend its body.
					bodyCandidate = (fromBody != null) ? fromBody : toBody;
				}

				bodyCandidate.addLimb(limb);
				partToBodyIndex.put(limb.getFromPart(), bodyCandidate);
				partToBodyIndex.put(limb.getToPart(), bodyCandidate);
			}
		}

		// Filter out the body duplicates (several parts map to the same body) and bodies with too few parts.
		return partToBodyIndex.values().stream()
				.distinct()
				.filter(body -> body.getParts().size() > poseProperties.getMinBodyPartCount())
				.collect(Collectors.toList());
	}

}