package cz.cuni.lf1.lge.ThunderSTORM.drift;

import cz.cuni.lf1.lge.ThunderSTORM.UI.GUI;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.PSF.Molecule;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.PSF.MoleculeDescriptor;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.RadialSymmetryFitter;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.SubImage;
import cz.cuni.lf1.lge.ThunderSTORM.rendering.ASHRendering;
import cz.cuni.lf1.lge.ThunderSTORM.rendering.RenderingMethod;
import cz.cuni.lf1.lge.ThunderSTORM.results.ModifiedLoess;
import cz.cuni.lf1.lge.ThunderSTORM.util.MathProxy;
import cz.cuni.lf1.lge.ThunderSTORM.util.Padding;
import cz.cuni.lf1.lge.ThunderSTORM.util.VectorMath;
import ij.IJ;
import ij.ImageStack;
import ij.process.FHT;
import ij.process.FloatProcessor;
import ij.process.ImageProcessor;
import org.apache.commons.math3.analysis.interpolation.LinearInterpolator;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;
import org.apache.commons.math3.util.MathArrays;

import java.awt.geom.Point2D;
import java.util.Arrays;

/**
 * Estimates lateral drift of localization data by image cross-correlation.
 * <p>
 * The detections are split by frame into consecutive bins. Each bin is rendered into an image and
 * cross-correlated with the image of the first bin. The position of the correlation peak gives the
 * shift of that bin relative to the first one. The peak is first located on a coarse rendering
 * (pixel size 1) and then refined on a fine rendering (pixel size 1/magnification). The per-bin
 * shifts are finally interpolated as functions of the frame number.
 */
public class CorrelationDriftEstimator {

    /**
     * Estimates drift by cross-correlating images rendered from consecutive groups of frames with
     * the image of the first group.
     * <p>
     * Note: {@code x}, {@code y} and {@code frame} are reordered in place (sorted by frame).
     *
     * @param x [px] x coordinates of the detections
     * @param y [px] y coordinates of the detections
     * @param frame frame number of each detection
     * @param steps requested number of bins (groups of frames), at least 2. Fewer bins may be
     *        used when there are not enough distinct frames to fill all of them.
     * @param magnification upsampling factor of the fine rendering used to refine the correlation
     *        peak (rendered pixel size is {@code 1 / magnification} px)
     * @param smoothingBandwidth - typically 0.25 - 0.5; loess bandwidth used when at least 4 bins
     *        are available, otherwise linear interpolation is used
     * @param roiWidth - [px] width of the original image, or a value less than 1 to use ceil(max(x))
     * @param roiHeight - [px] height of the original image, or a value less than 1 to use ceil(max(y))
     * @param saveCorrelationImages whether to keep the fine cross-correlation image of every bin
     *        except the first in the returned results
     * @return drift as functions of frame number, together with the per-bin drift values (in
     *         original pixel units)
     * @throws IllegalArgumentException if {@code steps < 2}
     * @throws RuntimeException if there are no detections, all detections are in a single frame,
     *         or the data cannot be split into at least two non-empty bins
     */
    public static CrossCorrelationDriftResults estimateDriftFromCoords(
            double[] x, double[] y, double[] frame,
            final int steps,
            final double magnification,
            final double smoothingBandwidth,
            int roiWidth, int roiHeight,
            boolean saveCorrelationImages) {

        if(steps < 2) {
            throw new IllegalArgumentException("Drift estimation requires at least 2 steps, got " + steps + ".");
        }
        final BinningResults bins = binResultByFrame(x, y, frame, steps);
        // Binning may return fewer bins than requested: trailing bins that would stay empty are
        // dropped, because an empty bin renders as a blank image with no usable correlation.
        final int binCount = bins.binCenters.length;
        if(binCount < 2) {
            throw new RuntimeException("Not enough frames to split the data into at least 2 parts.");
        }
        GUI.checkIJEscapePressed();
        int imageWidth = (roiWidth < 1) ? (int) MathProxy.ceil(VectorMath.max(x)) : roiWidth;
        int imageHeight = (roiHeight < 1) ? (int) MathProxy.ceil(VectorMath.max(y)) : roiHeight;

        // fine rendering (pixel size 1/magnification) for sub-pixel refinement of the peak,
        // coarse rendering (pixel size 1) for locating the peak in the first place
        RenderingMethod renderer = new ASHRendering.Builder().roi(0, imageWidth, 0, imageHeight).resolution(1 / magnification).shifts(2).build();
        RenderingMethod lowResRenderer = new ASHRendering.Builder().roi(0, imageWidth, 0, imageHeight).resolution(1).shifts(2).build();

        FloatProcessor firstImage = renderToFloat(renderer, bins.xBinnedByFrame[0], bins.yBinnedByFrame[0]);
        int paddedSize = MathProxy.nextPowerOf2(MathProxy.max(firstImage.getWidth(), firstImage.getHeight()));
        FHT firstImageFFT = createPaddedFFTImage(firstImage, paddedSize);

        FloatProcessor lowResFirstImage = renderToFloat(lowResRenderer, bins.xBinnedByFrame[0], bins.yBinnedByFrame[0]);
        int lowResPaddedSize = MathProxy.nextPowerOf2(MathProxy.max(lowResFirstImage.getWidth(), lowResFirstImage.getHeight()));
        FHT lowResFirstImageFFT = createPaddedFFTImage(lowResFirstImage, lowResPaddedSize);

        ImageStack correlationImages = null;
        if(saveCorrelationImages) {
            correlationImages = new ImageStack(paddedSize, paddedSize);
        }

        // Drift of each bin relative to the first one, in fine-image pixels until rescaled below.
        // The first bin is the reference, so its drift stays 0.
        double[] driftXofImage = new double[binCount];
        double[] driftYofImage = new double[binCount];

        for(int i = 1; i < binCount; i++) {
            IJ.showProgress((double) i / (double) (binCount - 1));
            IJ.showStatus("Processing part " + i + " from " + (binCount - 1) + "...");
            GUI.checkIJEscapePressed();

            FloatProcessor nextImage = renderToFloat(renderer, bins.xBinnedByFrame[i], bins.yBinnedByFrame[i]);
            FHT imageFFT = createPaddedFFTImage(nextImage, paddedSize);
            FHT crossCorrelationImage = computeCrossCorrelationImage(firstImageFFT, imageFFT);

            FloatProcessor lowResNextImage = renderToFloat(lowResRenderer, bins.xBinnedByFrame[i], bins.yBinnedByFrame[i]);
            FHT lowResImageFFT = createPaddedFFTImage(lowResNextImage, lowResPaddedSize);
            FHT lowResCrossCorrelationImage = computeCrossCorrelationImage(lowResFirstImageFFT, lowResImageFFT);

            if(saveCorrelationImages) {
                correlationImages.addSlice("", crossCorrelationImage);
            }

            // Find the maximum in the coarse image. The image is first weighted by a wide Gaussian
            // centred where the previous bin's drift (converted to coarse pixels) would put the
            // peak, presumably to favour small drift changes between neighbouring bins.
            multiplyImageByGaussianMask(new Point2D.Double(-driftXofImage[i - 1]/magnification + lowResPaddedSize / 2, -driftYofImage[i - 1]/magnification + lowResPaddedSize / 2), lowResPaddedSize, lowResCrossCorrelationImage);
            Point2D.Double lowResMaximumCoords = CorrelationDriftEstimator.findMaxima(lowResCrossCorrelationImage);
            lowResMaximumCoords = CorrelationDriftEstimator.findMaximaWithSubpixelPrecision(lowResMaximumCoords, 11, lowResCrossCorrelationImage);

            // translate the maximum from coarse to fine image coordinates (offset from the centre scales by magnification)
            Point2D.Double highResMaximumCoords = new Point2D.Double(
                    crossCorrelationImage.getWidth() / 2 + magnification * (lowResMaximumCoords.x - (lowResCrossCorrelationImage.getWidth() / 2)),
                    crossCorrelationImage.getHeight() / 2 + magnification * (lowResMaximumCoords.y - (lowResCrossCorrelationImage.getHeight() / 2)));
            // refine the maximum in the fine image
            highResMaximumCoords = CorrelationDriftEstimator.findMaximaWithSubpixelPrecision(highResMaximumCoords, 11, crossCorrelationImage);

            driftXofImage[i] = (crossCorrelationImage.getWidth() / 2 - highResMaximumCoords.x);
            driftYofImage[i] = (crossCorrelationImage.getHeight() / 2 - highResMaximumCoords.y);
        }

        // scale from fine-image pixels back to original pixels
        for(int i = 0; i < driftXofImage.length; i++) {
            driftXofImage[i] = driftXofImage[i] / magnification;
            driftYofImage[i] = driftYofImage[i] / magnification;
        }

        // interpolate the drift using the loess interpolator, or linear interpolation if there is not enough data for loess
        PolynomialSplineFunction xFunction;
        PolynomialSplineFunction yFunction;
        if(binCount < 4) {
            LinearInterpolator interpolator = new LinearInterpolator();
            xFunction = addLinearExtrapolationToBorders(interpolator.interpolate(bins.binCenters, driftXofImage), bins.minFrame, bins.maxFrame);
            yFunction = addLinearExtrapolationToBorders(interpolator.interpolate(bins.binCenters, driftYofImage), bins.minFrame, bins.maxFrame);
        } else {
            ModifiedLoess interpolator = new ModifiedLoess(smoothingBandwidth, 2);
            xFunction = addLinearExtrapolationToBorders(interpolator.interpolate(bins.binCenters, driftXofImage), bins.minFrame, bins.maxFrame);
            yFunction = addLinearExtrapolationToBorders(interpolator.interpolate(bins.binCenters, driftYofImage), bins.minFrame, bins.maxFrame);
        }

        IJ.showStatus("");
        IJ.showProgress(1.0);
        return new CrossCorrelationDriftResults(correlationImages, xFunction, yFunction, bins.binCenters, driftXofImage, driftYofImage, 1 / magnification, bins.minFrame, bins.maxFrame, MoleculeDescriptor.Units.PIXEL);
    }

    /**
     * Renders the given coordinates with {@code renderer} and returns the image converted to float.
     */
    private static FloatProcessor renderToFloat(RenderingMethod renderer, double[] xs, double[] ys) {
        return (FloatProcessor) renderer.getRenderedImage(xs, ys, null, null, null).getProcessor().convertToFloat();
    }

    /**
     * Pads the image with zeros to a {@code paddedSize x paddedSize} square and returns its forward
     * Hartley transform (progress display disabled).
     */
    private static FHT createPaddedFFTImage(FloatProcessor nextImage, int paddedSize) {
        FHT imageFFT = new FHT(Padding.PADDING_ZERO.padToBiggerSquare(nextImage, paddedSize));
        imageFFT.setShowProgress(false);
        imageFFT.transform();
        return imageFFT;
    }

    /**
     * Computes the spatial cross-correlation of two transformed images. The quadrants are swapped
     * so that zero shift lies in the centre of the returned image.
     */
    private static FHT computeCrossCorrelationImage(FHT image1FFT, FHT image2FFT) {
        FHT crossCorrelationImage = image1FFT.conjugateMultiply(image2FFT);
        crossCorrelationImage.setShowProgress(false);
        crossCorrelationImage.inverseTransform();
        crossCorrelationImage.swapQuadrants();
        return crossCorrelationImage;
    }

    /**
     * Multiplies every pixel of {@code image}, in place, by an unnormalized isotropic Gaussian
     * {@code exp(-r^2 / (2 sigma^2))} centred at {@code gaussianCenter}.
     */
    private static void multiplyImageByGaussianMask(Point2D.Double gaussianCenter, double gaussianSigma, FloatProcessor image) {
        double denominator = 2 * gaussianSigma * gaussianSigma;
        for(int y = 0; y < image.getHeight(); y++) {
            double dySquared = MathProxy.sqr(y - gaussianCenter.y);
            for(int x = 0; x < image.getWidth(); x++) {
                double maskValue = MathProxy.exp(-(MathProxy.sqr(x - gaussianCenter.x) + dySquared) / denominator);
                image.setf(x, y, (float) (image.getf(x, y) * maskValue));
            }
        }
    }

    /**
     * Holds the detections split into consecutive groups of frames, plus the frame range of the
     * whole data set.
     */
    private static class BinningResults {

        /** x coordinates of the detections in each bin */
        final double[][] xBinnedByFrame;
        /** y coordinates of the detections in each bin */
        final double[][] yBinnedByFrame;
        /** mean of the first and last frame of each bin; strictly increasing */
        final double[] binCenters;
        final int minFrame;
        final int maxFrame;

        public BinningResults(double[][] xBinnedByFrame, double[][] yBinnedByFrame, double[] binCenters, int minFrame, int maxFrame) {
            this.xBinnedByFrame = xBinnedByFrame;
            this.yBinnedByFrame = yBinnedByFrame;
            this.binCenters = binCenters;
            this.minFrame = minFrame;
            this.maxFrame = maxFrame;
        }
    }

    /**
     * Splits the detections into at most {@code binCount} consecutive bins of roughly equal size
     * without splitting any frame between two bins. The last bin takes all remaining detections.
     * Bins that would stay empty because the data ran out are omitted, so the result may contain
     * fewer than {@code binCount} bins.
     * <p>
     * Sorts {@code frame} in place and reorders {@code x} and {@code y} the same way.
     *
     * @throws RuntimeException if there are no detections or all of them are in one frame
     */
    private static BinningResults binResultByFrame(double[] x, double[] y, double[] frame, int binCount) {
        if(frame.length == 0) {
            throw new RuntimeException("No detections to estimate drift from.");
        }
        double minFrame = findMinFrame(frame);
        double maxFrame = findMaxFrame(frame);

        if(maxFrame == minFrame) {
            throw new RuntimeException("Requires multiple frames.");
        }

        MathArrays.sortInPlace(frame, x, y);
        // At least one detection per bin. With fewer detections than bins the quotient would be 0,
        // and frame[endPos - 1] below would read index -1.
        int detectionsPerBin = Math.max(1, frame.length / binCount);

        double[][] xBinnedByFrame = new double[binCount][];
        double[][] yBinnedByFrame = new double[binCount][];
        double[] binCenters = new double[binCount];
        int usedBins = 0;
        int currentPos = 0;
        while(usedBins < binCount && currentPos < frame.length) {
            int endPos = currentPos + detectionsPerBin;
            if(endPos >= frame.length || usedBins == binCount - 1) {
                endPos = frame.length;
            } else {
                // Extend the bin over all remaining detections of its last frame, up to and
                // including the final detection, so that no frame is shared by two bins.
                double frameAtEndPos = frame[endPos - 1];
                while(endPos < frame.length && frame[endPos] == frameAtEndPos) {
                    endPos++;
                }
            }
            xBinnedByFrame[usedBins] = Arrays.copyOfRange(x, currentPos, endPos);
            yBinnedByFrame[usedBins] = Arrays.copyOfRange(y, currentPos, endPos);
            binCenters[usedBins] = (frame[currentPos] + frame[endPos - 1]) / 2;
            usedBins++;
            currentPos = endPos;
        }
        if(usedBins < binCount) {
            xBinnedByFrame = Arrays.copyOf(xBinnedByFrame, usedBins);
            yBinnedByFrame = Arrays.copyOf(yBinnedByFrame, usedBins);
            binCenters = Arrays.copyOf(binCenters, usedBins);
        }
        return new BinningResults(xBinnedByFrame, yBinnedByFrame, binCenters, (int) minFrame, (int) maxFrame);
    }

    /** Returns the smallest value in {@code frame}, which must not be empty. */
    private static double findMinFrame(double[] frame) {
        double minFrame = frame[0];
        for(int i = 1; i < frame.length; i++) {
            if(frame[i] < minFrame) {
                minFrame = frame[i];
            }
        }
        return minFrame;
    }

    /** Returns the largest value in {@code frame}, which must not be empty. */
    private static double findMaxFrame(double[] frame) {
        double maxFrame = frame[0];
        for(int i = 1; i < frame.length; i++) {
            if(frame[i] > maxFrame) {
                maxFrame = frame[i];
            }
        }
        return maxFrame;
    }

    /**
     * Returns the integer coordinates of the brightest pixel of the image. On ties, the first such
     * pixel in row-major order is returned.
     */
    static Point2D.Double findMaxima(FloatProcessor crossCorrelationImage) {
        float[] pixels = (float[]) crossCorrelationImage.getPixels();
        int width = crossCorrelationImage.getWidth();
        int maxIndex = 0;
        for(int i = 1; i < pixels.length; i++) {
            if(pixels[i] > pixels[maxIndex]) {
                maxIndex = i;
            }
        }
        return new Point2D.Double(maxIndex % width, maxIndex / width);
    }

    /**
     * Refines a peak position by fitting a radial-symmetry model to the {@code roiSize x roiSize}
     * neighbourhood around its integer position.
     *
     * @param maximumCoords approximate peak position; its integer part selects the neighbourhood
     * @param roiSize side of the fitted neighbourhood (odd values keep it centred)
     * @param crossCorrelationImage image to fit
     * @return refined position, or {@code maximumCoords} unchanged when the neighbourhood would not
     *         fit inside the image
     */
    static Point2D.Double findMaximaWithSubpixelPrecision(Point2D.Double maximumCoords, int roiSize, FHT crossCorrelationImage) {
        int centerX = (int) maximumCoords.x;
        int centerY = (int) maximumCoords.y;
        int halfSize = (roiSize - 1) / 2;
        if(isCloseToBorder(centerX, centerY, halfSize, crossCorrelationImage)) {
            return maximumCoords;
        }

        float[] pixels = (float[]) crossCorrelationImage.getPixels();
        int width = crossCorrelationImage.getWidth();
        int roiX = centerX - halfSize;
        int roiY = centerY - halfSize;
        double[] subImageData = new double[roiSize * roiSize];
        for(int row = 0; row < roiSize; row++) {
            int src = (roiY + row) * width + roiX;
            int dst = row * roiSize;
            for(int col = 0; col < roiSize; col++) {
                subImageData[dst + col] = pixels[src + col];
            }
        }

        SubImage subImage = new SubImage(roiSize, roiSize, null, null, subImageData, 0, 0);
        Molecule psf = new RadialSymmetryFitter().fit(subImage);
        // NOTE(review): the fitted position is added to the integer centre. This assumes the fitter
        // reports positions relative to the centre of the sub-image; confirm in SubImage/RadialSymmetryFitter.
        return new Point2D.Double(centerX + psf.getX(), centerY + psf.getY());
    }

    /**
     * Extends a spline so that it covers {@code [minFrame, maxFrame]} by adding linear segments
     * before its first knot and/or after its last knot. Each added segment has the value and slope
     * of the spline at the knot it is attached to. Returns {@code spline} itself when no extension
     * is needed.
     */
    public static PolynomialSplineFunction addLinearExtrapolationToBorders(PolynomialSplineFunction spline, int minFrame, int maxFrame) {
        PolynomialFunction[] polynomials = spline.getPolynomials();
        double[] knots = spline.getKnots();
        int lastKnot = knots.length - 1;

        // Only extend outwards. Adding a segment on the wrong side of an existing knot would make
        // the knot sequence non-increasing.
        boolean addToBeginning = knots[0] > minFrame;
        boolean addToEnd = knots[lastKnot] < maxFrame;
        if(!addToBeginning && !addToEnd) {
            return spline; //do nothing
        }
        int offset = addToBeginning ? 1 : 0;
        int sizeIncrease = offset + (addToEnd ? 1 : 0);

        double[] newKnots = new double[knots.length + sizeIncrease];
        PolynomialFunction[] newPolynomials = new PolynomialFunction[polynomials.length + sizeIncrease];
        System.arraycopy(knots, 0, newKnots, offset, knots.length);
        System.arraycopy(polynomials, 0, newPolynomials, offset, polynomials.length);

        if(addToBeginning) {
            // Segment j of the spline is evaluated at (t - knots[j]), so the first segment's slope
            // at its own start is its derivative at 0.
            double derivativeAtFirstKnot = polynomials[0].polynomialDerivative().value(0);
            double valueAtFirstKnot = spline.value(knots[0]);
            newKnots[0] = minFrame;
            // line through (knots[0], valueAtFirstKnot), expressed relative to its start minFrame
            newPolynomials[0] = new PolynomialFunction(new double[]{valueAtFirstKnot - (knots[0] - minFrame) * derivativeAtFirstKnot, derivativeAtFirstKnot});
        }
        if(addToEnd) {
            // the last segment starts at knots[lastKnot - 1], so its end is at the offset below
            double derivativeAtLastKnot = polynomials[polynomials.length - 1].polynomialDerivative().value(knots[lastKnot] - knots[lastKnot - 1]);
            double valueAtLastKnot = spline.value(knots[lastKnot]);
            newKnots[newKnots.length - 1] = maxFrame;
            newPolynomials[newPolynomials.length - 1] = new PolynomialFunction(new double[]{valueAtLastKnot, derivativeAtLastKnot});
        }

        return new PolynomialSplineFunction(newKnots, newPolynomials);
    }

    /**
     * Tests whether a square window extending {@code subimageSize} pixels in every direction from
     * ({@code x}, {@code y}) would reach outside the image.
     */
    static boolean isCloseToBorder(int x, int y, int subimageSize, ImageProcessor image) {
        return x < subimageSize
                || x > image.getWidth() - subimageSize - 1
                || y < subimageSize
                || y > image.getHeight() - subimageSize - 1;
    }
}