/*
 * Copyright 2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package net.tzolov.cv.mtcnn;

import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import javax.imageio.ImageIO;

import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_imgproc;
import org.datavec.image.loader.Java2DNativeImageLoader;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner;
import org.tensorflow.framework.ConfigProto;

import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.util.Assert;

import static net.tzolov.cv.mtcnn.MtcnnUtil.CHANNEL_COUNT;
import static net.tzolov.cv.mtcnn.MtcnnUtil.C_ORDERING;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;

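/**
 * Face detection and alignment service implementing the three-stage MTCNN cascade
 * (proposal, refinement and output networks) on top of frozen TensorFlow graphs.
 *
 * @author Christian Tzolov
 */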
public class MtcnnService {

	// mtcnn models, frozen from the https://github.com/ipazc/mtcnn project
	// Note: if you enable these you also have to change the output labels read from the rnet and onet
	// results below (see the "for ipazc/mtcnn model" comments in the refinement and output stages).
	//public static final String TF_PNET_MODEL_URI = "classpath:/model/pnet_graph.proto";
	//public static final String TF_RNET_MODEL_URI = "classpath:/model/rnet_graph.proto";
	//public static final String TF_ONET_MODEL_URI = "classpath:/model/onet_graph.proto";

	// mtcnn models, frozen from the https://github.com/davidsandberg/facenet/tree/master/src/align project
	/** Resource URI of the frozen proposal network (pnet) graph, used by stage one. */
	public static final String TF_PNET_MODEL_URI = "classpath:/model2/pnet_graph.proto";
	/** Resource URI of the frozen refinement network (rnet) graph, used by stage two. */
	public static final String TF_RNET_MODEL_URI = "classpath:/model2/rnet_graph.proto";
	/** Resource URI of the frozen output network (onet) graph, used by stage three. */
	public static final String TF_ONET_MODEL_URI = "classpath:/model2/onet_graph.proto";

	/** Converts input images to {@link INDArray} and converts between {@link INDArray} and OpenCV Mat for resizing. */
	private final Java2DNativeImageLoader imageLoader;

	/** Runner for the pnet graph; fed through the "pnet/input" node. */
	private final GraphRunner proposeNetGraphRunner;
	/** Runner for the rnet graph; fed through the "rnet/input" node. */
	private final GraphRunner refineNetGraphRunner;
	/** Runner for the onet graph; fed through the "onet/input" node. */
	private final GraphRunner outputNetGraphRunner;

	// Experimental SameDiff based alternative to the GraphRunner instances above (disabled).
	//private final SameDiff outputNetGraph;
	//private final SameDiff refineNetGraph;
	//private final SameDiff proposeNetGraph;

	/** Minimum face size, passed to MtcnnUtil.computeScalePyramid when building the scale pyramid. */
	private final int minFaceSize;
	/** Scale factor, passed to MtcnnUtil.computeScalePyramid when building the scale pyramid. */
	private final double scaleFactor;
	/** Score thresholds per stage: [0] proposal stage, [1] refinement stage, [2] output stage. */
	private final double[] stepsThreshold;


	/**
	 * Creates a service with the default detection parameters: minimum face size of 20,
	 * scale factor of 0.709 and the stage thresholds { 0.6, 0.7, 0.7 }.
	 */
	public MtcnnService() {
		this(20, 0.709, new double[] { 0.6, 0.7, 0.7 });
	}

	/**
	 * Creates a service with custom detection parameters and loads the three frozen TensorFlow graphs.
	 * @param minFaceSize Minimum face size used when computing the scale pyramid.
	 * @param scaleFactor Scale factor used when computing the scale pyramid.
	 * @param stepsThreshold Score thresholds for the proposal, refinement and output stages (in that order).
	 *                       At least three values are required. The array is copied.
	 * @throws IllegalArgumentException if stepsThreshold is null or holds fewer than three values.
	 * @throws IllegalStateException if one of the model resources cannot be loaded.
	 */
	public MtcnnService(int minFaceSize, double scaleFactor, double[] stepsThreshold) {
		// Fail fast, before the (expensive) model loading, instead of failing later in the middle of a detection.
		Assert.notNull(stepsThreshold, "The stepsThreshold must not be null");
		Assert.isTrue(stepsThreshold.length >= 3, "The stepsThreshold must provide a threshold for each of the 3 stages");

		this.minFaceSize = minFaceSize;
		this.scaleFactor = scaleFactor;
		// Defensive copy: later changes to the caller's array must not alter the thresholds of this service.
		this.stepsThreshold = stepsThreshold.clone();

		this.imageLoader = new Java2DNativeImageLoader();

		this.proposeNetGraphRunner = this.createGraphRunner(TF_PNET_MODEL_URI, "pnet/input");
		this.refineNetGraphRunner = this.createGraphRunner(TF_RNET_MODEL_URI, "rnet/input");
		this.outputNetGraphRunner = this.createGraphRunner(TF_ONET_MODEL_URI, "onet/input");

		// Experimental
		//proposeNetGraph = TFGraphMapper.getInstance().importGraph(new DefaultResourceLoader().getResource(TF_PNET_MODEL_URI).getInputStream());
		//refineNetGraph = TFGraphMapper.getInstance().importGraph(new DefaultResourceLoader().getResource(TF_RNET_MODEL_URI).getInputStream());
		//outputNetGraph = TFGraphMapper.getInstance().importGraph(new DefaultResourceLoader().getResource(TF_ONET_MODEL_URI).getInputStream());
	}

	/**
	 * Loads a frozen TensorFlow graph and wraps it in a {@link GraphRunner}.
	 * @param tensorflowModelUri Resource URI (e.g. classpath:) of the serialized graph.
	 * @param inputLabel Name of the graph input node.
	 * @return Runner for the loaded graph, created with the default {@link ConfigProto}.
	 * @throws IllegalStateException if the model resource cannot be opened or read.
	 */
	private GraphRunner createGraphRunner(String tensorflowModelUri, String inputLabel) {
		// IOUtils.toByteArray leaves the stream open, so it is closed here with try-with-resources.
		try (InputStream modelStream = new DefaultResourceLoader().getResource(tensorflowModelUri).getInputStream()) {
			return new GraphRunner(
					IOUtils.toByteArray(modelStream),
					Arrays.asList(inputLabel),
					ConfigProto.getDefaultInstance());
		}
		catch (IOException e) {
			throw new IllegalStateException(String.format("Failed to load TF model [%s] and input [%s]:",
					tensorflowModelUri, inputLabel), e);
		}
	}

	/**
	 * Detects faces in the image referenced by a resource URI, and returns bounding boxes and points for them.
	 * @param imageUri Resource URI of the image to detect the faces in
	 * @return Array of face annotations found in the image
	 * @throws IOException if the image resource cannot be opened or read.
	 */
	public FaceAnnotation[] faceDetection(String imageUri) throws IOException {
		INDArray image;
		// Close the resource stream once the pixels are loaded; it is not needed during detection.
		try (InputStream imageStream = new DefaultResourceLoader().getResource(imageUri).getInputStream()) {
			// Take the first entry of the loader output: [ 3 x H x W ]
			image = this.imageLoader.asMatrix(imageStream)
					.get(point(0), all(), all(), all()).dup();
		}
		return faceDetection(image);
	}

	/**
	 * Detects faces in an already decoded image, and returns bounding boxes and points for them.
	 * @param bImage image to detect the faces in
	 * @return Array of face annotations found in the image
	 * @throws IOException if the image cannot be converted or one of the detection stages fails.
	 */
	public FaceAnnotation[] faceDetection(BufferedImage bImage) throws IOException {
		// The loader yields a 4 dimensional array; its first entry is the image handed to the detector.
		INDArray loaded = this.imageLoader.asMatrix(bImage);
		return faceDetection(loaded.get(point(0), all(), all(), all()));
	}

	/**
	 * Detects faces in raw image bytes of known height and width.
	 * The bytes are converted with MtcnnUtil.imageByteToFloatArray, viewed as [ h x w x 3 ] and then
	 * reordered to the channels first [ 3 x h x w ] layout expected by the detector.
	 * @param byteImage raw image bytes (presumably 3 values per pixel - confirm against MtcnnUtil.imageByteToFloatArray)
	 * @param h image height
	 * @param w image width
	 * @return Array of face annotations found in the image
	 * @throws IOException if one of the detection stages fails.
	 */
	public FaceAnnotation[] faceDetection(byte[] byteImage, int h, int w) throws IOException {
		INDArray flatPixels = Nd4j.create(MtcnnUtil.imageByteToFloatArray(byteImage));
		INDArray imageHW3 = flatPixels.reshape(new int[] { h, w, 3 });
		// in-place permutation: [ h x w x 3 ] -> [ 3 x h x w ]
		return faceDetection(imageHW3.permutei(2, 0, 1));
	}

	/**
	 * Detects faces for byte encoded input images. Supports only byte arrays that carry their image
	 * format, e.g. as produced by ImageIO.write(bufferImage, format) or MtcnnUtil.toByteArray(bi2, "png")
	 * @param byteImage Input image encoded in bytes along with its image format spec.
	 * @return Array of face bounding boxes found in the image
	 * @throws IOException if the bytes cannot be read or decoded as an image.
	 */
	public FaceAnnotation[] faceDetection(byte[] byteImage) throws IOException {
		BufferedImage bufferedImage = ImageIO.read(new ByteArrayInputStream(byteImage));
		// ImageIO.read returns null (it does not throw) when no registered reader can decode the input.
		if (bufferedImage == null) {
			throw new IOException("Unsupported or corrupted image data: no image reader can decode the input bytes");
		}
		bufferedImage = MtcnnUtil.to3ByteBGR(bufferedImage);
		return faceDetection(bufferedImage);
	}

	/**
	 * Runs the raw detection and converts its boxes and points into face annotations.
	 * @param image3HW image to detect the faces in. Expected dimensions [ 3 x H x W ]
	 * @return Array of face annotations found in the image
	 * @throws IOException if one of the detection stages fails.
	 */
	public FaceAnnotation[] faceDetection(INDArray image3HW) throws IOException {

		INDArray[] detection = this.rawFaceDetection(image3HW);
		INDArray boxes = detection[0];
		INDArray landmarks = detection[1];

		// 1.0.0-SNAPSHOT condition; with 1.0.0-beta2 the check was size(0) > 1
		boolean hasBoxes = !boxes.isEmpty() && boxes.size(0) > 0;
		if (hasBoxes) {
			// the points are transposed before being handed to the annotation converter
			landmarks = landmarks.transpose();
		}

		return MtcnnUtil.toFaceAnnotation(boxes, landmarks);
	}

	/**
	 * Aligns every annotated face of an image. See the single annotation variant for the details.
	 * @param image source image, the one the annotations were computed for
	 * @param bboxes face annotations to align
	 * @param margin margin added around each bounding box before cropping
	 * @param alignedImageSize width and height of every aligned output image
	 * @param preWhiten if true the aligned images are passed through MtcnnUtil.preWhiten
	 * @return one aligned image per annotation, in the order of the input annotations
	 * @throws IOException if resizing a cropped face fails.
	 */
	public INDArray[] faceAlignment(INDArray image, FaceAnnotation[] bboxes, int margin, int alignedImageSize, boolean preWhiten) throws IOException {
		INDArray[] aligned = new INDArray[bboxes.length];
		int next = 0;
		for (FaceAnnotation annotation : bboxes) {
			aligned[next++] = this.faceAlignment(image, annotation, margin, alignedImageSize, preWhiten);
		}
		return aligned;
	}

	/**
	 * Crops a face out of an image, resizes it to a square and optionally pre-whitens it.
	 * The crop is the bounding box grown by half of the margin on every side and clamped to the image,
	 * where dimension 1 of the image is taken as the height and dimension 2 as the width.
	 * @param image source image; dimension 0 must hold the channels (checked by resize)
	 * @param faceAnnotation annotation providing the bounding box of the face
	 * @param margin total margin added to the bounding box (half on each side, integer division)
	 * @param alignedImageSize width and height of the aligned output image
	 * @param preWhiten if true the result is passed through MtcnnUtil.preWhiten
	 * @return the cropped, resized and optionally pre-whitened face image
	 * @throws IOException if resizing the cropped face fails.
	 */
	public INDArray faceAlignment(INDArray image, FaceAnnotation faceAnnotation, int margin, int alignedImageSize, boolean preWhiten) throws IOException {
		FaceAnnotation.BoundingBox box = faceAnnotation.getBoundingBox();
		int left = box.getX();
		int top = box.getY();
		int boxWidth = box.getW();
		int boxHeight = box.getH();

		int halfMargin = margin / 2;

		// grow the box by half of the margin on each side, clamped to the image borders
		int cropTop = Math.max(top - halfMargin, 0);
		int cropLeft = Math.max(left - halfMargin, 0);
		int cropBottom = Math.min(top + boxHeight + halfMargin, (int) image.shape()[1]);
		int cropRight = Math.min(left + boxWidth + halfMargin, (int) image.shape()[2]);

		INDArray face = MtcnnUtil.crop(image, cropLeft, cropRight, cropTop, cropBottom);

		// Size is given as (W, H)
		INDArray resizedFace = this.resize(face, new opencv_core.Size(alignedImageSize, alignedImageSize));

		return preWhiten ? MtcnnUtil.preWhiten(resizedFace) : resizedFace;
	}

	/**
	 * Detect faces and related points by running the three detection stages in sequence:
	 * the preparation (proposal) stage, the refinement stage and the output stage.
	 * @param image3HW input image with dimensions [C x H x W] (e.g. channels first). Must have rank 3 and 3 channels.
	 * @return Two INDArray elements representing the Total Boxes found and the related points.
	 * @throws IOException propagated from the detection stages.
	 * @throws IllegalArgumentException if the input image does not have the expected rank or channel count.
	 */
	public INDArray[] rawFaceDetection(INDArray image3HW) throws IOException {

		// Validate the input before any workspace is activated.
		Assert.isTrue(image3HW.rank() == 3, "The input image is expected to have [Channels, Height, Width] dimensions");
		Assert.isTrue(image3HW.shape()[0] == 3, "The input image is expected to have channel count at dimension 0");

		WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder()
				.initialSize(10 * 1024L * 1024L)
				.policyAllocation(AllocationPolicy.STRICT)
				.policyLearning(LearningPolicy.NONE)
				.build();

		// NOTE(review): the returned arrays are created while this workspace is active - confirm they stay
		// valid for the callers after the workspace is closed.
		try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID")) {

			// Compute the scale pyramid
			int height = (int) image3HW.size(1);
			int width = (int) image3HW.size(2);

			List<Double> scales = MtcnnUtil.computeScalePyramid(height, width, this.minFaceSize, this.scaleFactor);

			// Stage One
			Object[] stageOneResult = this.preparationStage(image3HW, scales);

			// Reorder image dimensions from [3,H,W] to [H,W,3]
			INDArray imageHW3 = image3HW.permute(1, 2, 0);

			// Stage Two
			INDArray totalBoxes = this.refinementStage(imageHW3, (INDArray) stageOneResult[0], (MtcnnUtil.PadResult) stageOneResult[1]);

			// Stage Three
			return this.outputStage(imageHW3, totalBoxes);
		}
	}

	/**
	 * STAGE 1
	 *
	 * Runs the proposal network (pnet) on the image at every scale of the pyramid, collects the candidate
	 * boxes that survive non-max suppression and, when any are found, calibrates them with their regression
	 * offsets, passes them through MtcnnUtil.rerec and computes their padding against the image bounds.
	 *
	 * @param image3HW input image with dimensions [3 x H x W]
	 * @param scales scale factors of the pyramid; the network is run once per scale
	 * @return two element array: [0] the total boxes ({@link INDArray}, empty when nothing was found) and
	 * [1] the {@link MtcnnUtil.PadResult} for those boxes (null when no boxes were found)
	 * @throws IOException if resizing the image fails
	 */
	private Object[] preparationStage(INDArray image3HW, List<Double> scales) throws IOException {

		INDArray totalBoxes = Nd4j.empty();
		MtcnnUtil.PadResult padResult = null;

		double imageHeight = image3HW.size(1);
		double imageWidth = image3HW.size(2);

		for (Double scale : scales) {

			// size of the image at the current pyramid scale
			int newWidth = (int) Math.ceil(imageWidth * scale);
			int newHeight = (int) Math.ceil(imageHeight * scale);

			//[0, W, H, 3]
			INDArray image0WH3 = resize(image3HW, new opencv_core.Size(newWidth, newHeight)).permute(0, 3, 2, 1).dup();
			// normalize the pixel values: (value - 127.5) * 0.0078125
			//image0WH3 = image0WH3.subi(127.5).muli(0.0078125);
			image0WH3 = image0WH3.sub(127.5).mul(0.0078125);

			// img_x = np.expand_dims(scaled_image, 0)
			// img_y = np.transpose(img_x, (0, 2, 1, 3))

			// Experimental SameDiff based execution (disabled):
			//this.proposeNetGraph.associateArrayWithVariable(image0WH3, this.proposeNetGraph.variableMap().get("pnet/input"));
			//List<DifferentialFunction> proposeNetResults = this.proposeNetGraph.exec().getRight();
			//INDArray out0 = proposeNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("pnet/conv4-2/BiasAdd"))
			//		.findFirst().get().outputVariable().getArr(); //.permutei(0, 2, 1, 3);
			//INDArray out1 = proposeNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("pnet/prob1"))
			//		.findFirst().get().outputVariable().getArr(); //.permutei(0, 2, 1, 3);

			Map<String, INDArray> resultMap = this.proposeNetGraphRunner.run(Collections.singletonMap("pnet/input", image0WH3));
			INDArray out0 = resultMap.get("pnet/conv4-2/BiasAdd");//.permutei(0, 2, 1, 3);
			INDArray out1 = resultMap.get("pnet/prob1");//.permutei(0, 2, 1, 3);

			// boxes, _ = self.__generate_bounding_box(out1[0, :, :, 1].copy(),
			//    out0[0, :, :, :].copy(), scale, self.__steps_threshold[0])
			INDArray boxes = MtcnnUtil.generateBoundingBox(out1.get(point(0), all(), all(), point(1)),
					out0.get(point(0), all(), all(), all()), scale, this.stepsThreshold[0])[0];

			if (!boxes.isEmpty()) {
				// per scale suppression (threshold 0.5); only the picked boxes are added to the total
				INDArray pick = MtcnnUtil.nonMaxSuppression(boxes, 0.5, MtcnnUtil.NonMaxSuppressionType.Union);
				if (boxes.length() > 0 && pick.length() > 0 && !pick.isEmpty()) {
					boxes = boxes.get(new SpecifiedIndex(pick.toLongVector()), all());
					if (totalBoxes.isEmpty()) {
						totalBoxes = boxes;
					}
					else {
						totalBoxes = MtcnnUtil.append(totalBoxes, boxes, 0);
					}
				}
			}
		}

		long numBoxes = totalBoxes.isEmpty() ? 0 : totalBoxes.shape()[0];
		if (numBoxes > 0) {
			// suppression across all scales (threshold 0.7)
			INDArray pick = MtcnnUtil.nonMaxSuppression(totalBoxes, 0.7, MtcnnUtil.NonMaxSuppressionType.Union);
			totalBoxes = totalBoxes.get(new SpecifiedIndex(pick.toLongVector()), all());

			// regw = total_boxes[:, 2] - total_boxes[:, 0]
			// regh = total_boxes[:, 3] - total_boxes[:, 1]
			INDArray x2 = totalBoxes.get(all(), point(2));
			INDArray x1 = totalBoxes.get(all(), point(0));
			INDArray y2 = totalBoxes.get(all(), point(3));
			INDArray y1 = totalBoxes.get(all(), point(1));

			INDArray regw = x2.sub(x1);
			INDArray regh = y2.sub(y1);

			// qq1 = total_boxes[:, 0] + total_boxes[:, 5] * regw
			// qq2 = total_boxes[:, 1] + total_boxes[:, 6] * regh
			// qq3 = total_boxes[:, 2] + total_boxes[:, 7] * regw
			// qq4 = total_boxes[:, 3] + total_boxes[:, 8] * regh
			INDArray qq1 = x1.add(totalBoxes.get(all(), point(5)).mul(regw));
			INDArray qq2 = y1.add(totalBoxes.get(all(), point(6)).mul(regh));
			INDArray qq3 = x2.add(totalBoxes.get(all(), point(7)).mul(regw));
			INDArray qq4 = y2.add(totalBoxes.get(all(), point(8)).mul(regh));

			// total_boxes = np.transpose(np.vstack([qq1, qq2, qq3, qq4, total_boxes[:, 4]]))
			totalBoxes = Nd4j.hstack(qq1, qq2, qq3, qq4, totalBoxes.get(all(), point(4)));

			// total_boxes = self.__rerec(total_boxes.copy())
			totalBoxes = MtcnnUtil.rerec(totalBoxes, true);

			// padding of the boxes against the image bounds; consumed by the refinement stage
			padResult = MtcnnUtil.pad(totalBoxes, (int) imageWidth, (int) imageHeight);
		}

		return new Object[] { totalBoxes, padResult };
	}

	/**
	 *  STAGE 2
	 *
	 * Crops every candidate box out of the image as a 24x24 patch, scores the patches with the refinement
	 * network (rnet), keeps the boxes scoring above the second stage threshold and applies non-max
	 * suppression, bounding box regression and MtcnnUtil.rerec to the survivors.
	 *
	 * @param image input image with dimensions [H x W x 3]
	 * @param totalBoxes candidate boxes produced by stage one (may be empty)
	 * @param padResult padding information for the candidate boxes; not accessed when totalBoxes is empty
	 * @return the refined boxes; the untouched input when it holds no boxes, or an empty array when no
	 * candidate passes the threshold or a candidate patch is degenerate
	 * @throws IOException if resizing a candidate patch fails
	 */
	private INDArray refinementStage(INDArray image, INDArray totalBoxes, MtcnnUtil.PadResult padResult) throws IOException {

		// num_boxes = total_boxes.shape[0]
		int numBoxes = totalBoxes.isEmpty() ? 0 : (int) totalBoxes.shape()[0];
		// if num_boxes == 0:
		//   return total_boxes, stage_status
		if (numBoxes == 0) {
			return totalBoxes;
		}

		INDArray tempImg1 = computeTempImage(image, numBoxes, padResult, 24);

		// computeTempImage returns an empty array for a degenerate patch. Like the reference implementation,
		// finish the stage without boxes instead of feeding an empty input to the network.
		if (tempImg1.isEmpty()) {
			return Nd4j.empty();
		}

		//this.refineNetGraph.associateArrayWithVariable(tempImg1, this.refineNetGraph.variableMap().get("rnet/input"));
		//List<DifferentialFunction> refineNetResults = this.refineNetGraph.exec().getRight();
		//INDArray out0 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/fc2-2/fc2-2"))
		//		.findFirst().get().outputVariable().getArr();
		//INDArray out1 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/prob1"))
		//		.findFirst().get().outputVariable().getArr();

		Map<String, INDArray> resultMap = this.refineNetGraphRunner.run(Collections.singletonMap("rnet/input", tempImg1));
		//INDArray out0 = resultMap.get("rnet/fc2-2/fc2-2");  // for ipazc/mtcnn model
		INDArray out0 = resultMap.get("rnet/conv5-2/conv5-2");
		INDArray out1 = resultMap.get("rnet/prob1");

		//  score = out1[1, :]
		INDArray score = out1.get(all(), point(1)).transposei();

		// ipass = np.where(score > self.__steps_threshold[1])
		INDArray ipass = MtcnnUtil.getIndexWhereVector(score.transpose(), s -> s > stepsThreshold[1]);
		//INDArray ipass = MtcnnUtil.getIndexWhereVector2(score.transpose(), Conditions.greaterThan(stepsThreshold[1]));

		if (ipass.isEmpty()) {
			return Nd4j.empty();
		}
		// total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), np.expand_dims(score[ipass].copy(), 1)])
		INDArray b1 = totalBoxes.get(new SpecifiedIndex(ipass.toLongVector()), interval(0, 4));
		// a single passing index yields a scalar score, which has to be shaped as a [1 x 1] column explicitly
		INDArray b2 = ipass.isScalar() ? score.get(ipass).reshape(1, 1)
				: Nd4j.expandDims(score.get(ipass), 1);
		totalBoxes = Nd4j.hstack(b1, b2);

		// mv = out0[:, ipass[0]]
		INDArray mv = out0.get(new SpecifiedIndex(ipass.toLongVector()), all()).transposei();

		// if total_boxes.shape[0] > 0:
		if (!totalBoxes.isEmpty() && totalBoxes.shape()[0] > 0) {
			// pick = self.__nms(total_boxes, 0.7, 'Union')
			INDArray pick = MtcnnUtil.nonMaxSuppression(totalBoxes.dup(), 0.7, MtcnnUtil.NonMaxSuppressionType.Union).transpose();

			// total_boxes = total_boxes[pick, :]
			totalBoxes = totalBoxes.get(new SpecifiedIndex(pick.toLongVector()), all());

			// total_boxes = self.__bbreg(total_boxes.copy(), np.transpose(mv[:, pick]))
			totalBoxes = MtcnnUtil.bbreg(totalBoxes, mv.get(all(), new SpecifiedIndex(pick.toLongVector())).transpose());

			// total_boxes = self.__rerec(total_boxes.copy())
			totalBoxes = MtcnnUtil.rerec(totalBoxes, false);
		}

		return totalBoxes;
	}

	/**
	 * STAGE 3
	 *
	 * Crops every refined box out of the image as a 48x48 patch, runs the output network (onet) on the
	 * patches, keeps the boxes scoring above the third stage threshold, maps the predicted points from
	 * box relative to image coordinates and applies bounding box regression and non-max suppression.
	 *
	 * @param image input image with dimensions [H x W x 3]
	 * @param totalBoxes boxes produced by stage two (may be empty)
	 * @return two element array: [0] the final boxes and [1] the related points. Both are empty when there
	 * are no input boxes, no box passes the threshold or a box patch is degenerate
	 * @throws IOException if resizing a box patch fails
	 */
	private INDArray[] outputStage(INDArray image, INDArray totalBoxes) throws IOException {

		// num_boxes = total_boxes.shape[0]
		int numBoxes = totalBoxes.isEmpty() ? 0 : (int) totalBoxes.shape()[0];
		// if num_boxes == 0:
		//   return total_boxes, stage_status
		if (numBoxes == 0) {
			return new INDArray[] { totalBoxes, Nd4j.empty() };
		}

		// total_boxes = np.fix(total_boxes).astype(np.int32)
		totalBoxes = Transforms.floor(totalBoxes);

		// status = StageStatus(self.__pad(total_boxes.copy(), stage_status.width, stage_status.height),
		//                             width=stage_status.width, height=stage_status.height)

		// the image is [H x W x 3] here: dimension 1 is the width and dimension 0 the height
		MtcnnUtil.PadResult padResult = MtcnnUtil.pad(totalBoxes, (int) image.shape()[1], (int) image.shape()[0]);

		INDArray tempImg1 = computeTempImage(image, numBoxes, padResult, 48);

		// computeTempImage returns an empty array for a degenerate patch. Like the reference implementation,
		// finish the stage without results instead of feeding an empty input to the network.
		if (tempImg1.isEmpty()) {
			return new INDArray[] { Nd4j.empty(), Nd4j.empty() };
		}

		//this.outputNetGraph.associateArrayWithVariable(tempImg1, this.outputNetGraph.variableMap().get("onet/input"));
		//List<DifferentialFunction> outputNetResults = this.outputNetGraph.exec().getRight();
		//
		//INDArray out0 = outputNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("onet/fc2-2/fc2-2"))
		//		.findFirst().get().outputVariable().getArr();
		//INDArray out1 = outputNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("onet/fc2-3/fc2-3"))
		//		.findFirst().get().outputVariable().getArr();
		//INDArray out2 = outputNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("onet/prob1"))
		//		.findFirst().get().outputVariable().getArr();

		Map<String, INDArray> resultMap = this.outputNetGraphRunner.run(Collections.singletonMap("onet/input", tempImg1));

		INDArray out0 = resultMap.get("onet/conv6-2/conv6-2");
		INDArray out1 = resultMap.get("onet/conv6-3/conv6-3");

		//INDArray out0 = resultMap.get("onet/fc2-2/fc2-2"); // for ipazc/mtcnn model
		//INDArray out1 = resultMap.get("onet/fc2-3/fc2-3"); // for ipazc/mtcnn model

		INDArray out2 = resultMap.get("onet/prob1");

		// score = out2[1, :]
		//INDArray score = out2.get(point(1), all());
		INDArray score = out2.get(all(), point(1)).transposei();

		// points = out1
		INDArray points = out1;

		// ipass = np.where(score > self.__steps_threshold[2])
		INDArray ipass = MtcnnUtil.getIndexWhereVector(score.transpose(), s -> s > stepsThreshold[2]);
		//INDArray ipass = MtcnnUtil.getIndexWhereVector2(score.transpose(), Conditions.greaterThan(stepsThreshold[2]));

		if (ipass.isEmpty()) {
			return new INDArray[] { Nd4j.empty(), Nd4j.empty() };
		}

		// points = points[:, ipass[0]]
		points = points.get(new SpecifiedIndex(ipass.toLongVector()), all()).transposei();

		// total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), np.expand_dims(score[ipass].copy(), 1)])
		INDArray b1 = totalBoxes.get(new SpecifiedIndex(ipass.toLongVector()), interval(0, 4)).dup();
		// a single passing index yields a scalar score, which has to be shaped as a [1 x 1] column explicitly
		INDArray b2 = ipass.isScalar() ? score.get(ipass).reshape(1, 1).dup()
				: Nd4j.expandDims(score.get(ipass).dup(), 1);
		totalBoxes = Nd4j.hstack(b1, b2);

		// mv = out0[:, ipass[0]]
		INDArray mv = out0.get(new SpecifiedIndex(ipass.toLongVector()), all()).transposei();

		//  w = total_boxes[:, 2] - total_boxes[:, 0] + 1
		//  h = total_boxes[:, 3] - total_boxes[:, 1] + 1
		INDArray w = totalBoxes.get(all(), point(2)).dup().subi(totalBoxes.get(all(), point(0))).addi(1);
		INDArray h = totalBoxes.get(all(), point(3)).dup().subi(totalBoxes.get(all(), point(1))).addi(1);

		// points[0:5, :] = np.tile(w, (5, 1)) * points[0:5, :] + np.tile(total_boxes[:, 0], (5, 1)) - 1
		// points[5:10, :] = np.tile(h, (5, 1)) * points[5:10, :] + np.tile(total_boxes[:, 1], (5, 1)) - 1
		points.put(new INDArrayIndex[] { interval(0, 5), all() },
				Nd4j.repeat(w, 5)
						.muli(points.get(interval(0, 5), all()))
						.addi(Nd4j.repeat(totalBoxes.get(all(), point(0)), 5))
						.subi(1));

		points.put(new INDArrayIndex[] { interval(5, 10), all() },
				Nd4j.repeat(h, 5)
						.muli(points.get(interval(5, 10), all()))
						.addi(Nd4j.repeat(totalBoxes.get(all(), point(1)), 5))
						.subi(1));

		if (totalBoxes.shape()[0] > 0) {

			// total_boxes = self.__bbreg(total_boxes.copy(), np.transpose(mv))
			totalBoxes = MtcnnUtil.bbreg(totalBoxes.dup(), mv.transpose());

			// pick = self.__nms(total_boxes.copy(), 0.7, 'Min')
			INDArray pick = MtcnnUtil.nonMaxSuppression(totalBoxes.dup(), 0.7, MtcnnUtil.NonMaxSuppressionType.Min).transpose();

			//  total_boxes = total_boxes[pick, :]
			totalBoxes = totalBoxes.get(new SpecifiedIndex(pick.toLongVector()), all());

			// points = points[:, pick]
			points = points.get(all(), new SpecifiedIndex(pick.toLongVector()));
		}

		return new INDArray[] { totalBoxes, points };
	}

	/**
	 * Builds the input batch for the refinement and output networks: for every box the matching image
	 * region is copied into a zero initialized patch, the patch is resized to [size x size] and stored
	 * in the batch. The batch values are then normalized and its dimensions reordered.
	 *
	 * @param image input image, indexed here as [row, column, channel]
	 * @param numBoxes number of boxes to crop; one patch is produced per box
	 * @param padResult per box source (y, ey, x, ex) and target (dy, edy, dx, edx) ranges plus the patch
	 * sizes (tmph, tmpw)
	 * @param size width and height of the resized patches
	 * @return the normalized batch with the box index as the first dimension, or an empty array as soon
	 * as a patch with exactly one zero sized dimension is met
	 * @throws IOException if resizing a patch fails
	 */
	private INDArray computeTempImage(INDArray image, int numBoxes, MtcnnUtil.PadResult padResult, int size) throws IOException {

		//  tempimg = np.zeros(shape=(size, size, 3, num_boxes))
		INDArray tempImg = Nd4j.zeros(new int[] { size, size, CHANNEL_COUNT, numBoxes }, C_ORDERING);

		opencv_core.Size newSize = new opencv_core.Size(size, size);

		for (int k = 0; k < numBoxes; k++) {
			//tmp = np.zeros((int(stage_status.tmph[k]), int(stage_status.tmpw[k]), 3))
			INDArray tmp = Nd4j.zeros(new int[] { padResult.getTmph().getInt(k), padResult.getTmpw().getInt(k), CHANNEL_COUNT }, C_ORDERING);

			// tmp[stage_status.dy[k] - 1:stage_status.edy[k], stage_status.dx[k] - 1:stage_status.edx[k], :] = \
			//   img[stage_status.y[k] - 1:stage_status.ey[k], stage_status.x[k] - 1:stage_status.ex[k], :]
			// The "- 1" shifts the start of every range by one, exactly as in the reference code above.
			tmp.put(new INDArrayIndex[] {
							interval(padResult.getDy().getInt(k) - 1, padResult.getEdy().getInt(k)),
							interval(padResult.getDx().getInt(k) - 1, padResult.getEdx().getInt(k)),
							all() },
					image.get(
							interval(padResult.getY().getInt(k) - 1, padResult.getEy().getInt(k)),
							interval(padResult.getX().getInt(k) - 1, padResult.getEx().getInt(k)),
							all()));

			// if tmp.shape[0] > 0 and tmp.shape[1] > 0 or tmp.shape[0] == 0 and tmp.shape[1] == 0:
			//    tempimg[:, :, :, k] = cv2.resize(tmp, (size, size), interpolation=cv2.INTER_AREA)
			if ((tmp.shape()[0] > 0 && tmp.shape()[1] > 0) || (tmp.shape()[0] == 0 && tmp.shape()[1] == 0)) {

				// resize expects the channels first, so the patch is permuted before the call; the first
				// entry of the resize result is permuted back to channels last before it is stored
				INDArray resizedImage = resize(tmp.permutei(2, 0, 1).dup(), newSize)
						.get(point(0), all(), all(), all()).permutei(1, 2, 0).dup();

				tempImg.put(new INDArrayIndex[] { all(), all(), all(), point(k) }, resizedImage);
			}
			else {
				// exactly one of the patch dimensions is zero: give up and signal it with an empty array
				return Nd4j.empty();
			}
		}

		// tempimg = (tempimg - 127.5) * 0.0078125
		tempImg = tempImg.subi(127.5).muli(0.0078125);

		// tempimg1 = np.transpose(tempimg, (3, 1, 0, 2))
		INDArray tempImg1 = tempImg.permutei(3, 1, 0, 2).dup();

		return tempImg1;
	}

	/**
	 * Resize an {@link INDArray} encoded image using OpenCV area interpolation.
	 * @param imageCHW Image to resize. Expects [CHANNEL, HEIGHT, WIDTH] dimensions with 3 channels.
	 * @param newSizeWH new image size (w,h)
	 * @return Returns the resized image as a 4 dimensional {@link INDArray}. The callers in this class
	 * index it as [BATCH, CHANNEL, HEIGHT, WIDTH] with a single batch entry.
	 * @throws IOException if the conversion between {@link INDArray} and the OpenCV Mat fails
	 * @throws IllegalArgumentException if dimension 0 of the input does not hold 3 channels
	 */
	public INDArray resize(INDArray imageCHW, opencv_core.Size newSizeWH) throws IOException {
		Assert.isTrue(imageCHW.size(0) == CHANNEL_COUNT, "Input image is expected to have the [3, H, W] dimensions");
		// Mat expects [C, H, W] dimensions
		opencv_core.Mat mat = imageLoader.asMat(imageCHW);
		// the Mat serves as both the source and the destination of the resize
		opencv_imgproc.resize(mat, mat, newSizeWH, 0, 0, opencv_imgproc.CV_INTER_AREA);
		// the callers in this class index the result as [0, 3, H, W]
		return imageLoader.asMatrix(mat);
	}
}