/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.eval;

import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;

import static junit.framework.TestCase.assertNull;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;


public class EvalJsonTest extends BaseDL4JTest {

    @Test
    public void testSerdeEmpty() {
        boolean print = false;

        org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
                        new ROCMultiClass(10), new RegressionEvaluation(3), new RegressionEvaluation(),
                        new EvaluationCalibration()};

        for (org.nd4j.evaluation.IEvaluation e : arr) {
            String json = e.toJson();
            String stats = e.stats();
            if (print) {
                System.out.println(e.getClass() + "\n" + json + "\n\n");
            }

            IEvaluation fromJson = (IEvaluation) org.nd4j.evaluation.BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
            assertEquals(e.toJson(), fromJson.toJson());
        }
    }

    @Test
    public void testSerde() {
        boolean print = false;
        Nd4j.getRandom().setSeed(12345);

        Evaluation evaluation = new Evaluation();
        EvaluationBinary evaluationBinary = new EvaluationBinary();
        ROC roc = new ROC(2);
        ROCBinary roc2 = new ROCBinary(2);
        ROCMultiClass roc3 = new ROCMultiClass(2);
        RegressionEvaluation regressionEvaluation = new RegressionEvaluation();
        EvaluationCalibration ec = new EvaluationCalibration();


        org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec};

        INDArray evalLabel = Nd4j.create(10, 3);
        for (int i = 0; i < 10; i++) {
            evalLabel.putScalar(i, i % 3, 1.0);
        }
        INDArray evalProb = Nd4j.rand(10, 3);
        evalProb.diviColumnVector(evalProb.sum(true,1));
        evaluation.eval(evalLabel, evalProb);
        roc3.eval(evalLabel, evalProb);
        ec.eval(evalLabel, evalProb);

        evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5));
        evalProb = Nd4j.rand(10, 3);
        evaluationBinary.eval(evalLabel, evalProb);
        roc2.eval(evalLabel, evalProb);

        evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5));
        evalProb = Nd4j.rand(10, 1);
        roc.eval(evalLabel, evalProb);

        regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3));



        for (org.nd4j.evaluation.IEvaluation e : arr) {
            String json = e.toJson();
            if (print) {
                System.out.println(e.getClass() + "\n" + json + "\n\n");
            }

            IEvaluation fromJson = (IEvaluation) BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
            assertEquals(e.toJson(), fromJson.toJson());
        }
    }

    @Test
    public void testSerdeExactRoc() {
        Nd4j.getRandom().setSeed(12345);
        boolean print = false;

        ROC roc = new ROC(0);
        ROCBinary roc2 = new ROCBinary(0);
        ROCMultiClass roc3 = new ROCMultiClass(0);


        org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {roc, roc2, roc3};

        INDArray evalLabel = Nd4j.create(100, 3);
        for (int i = 0; i < 100; i++) {
            evalLabel.putScalar(i, i % 3, 1.0);
        }
        INDArray evalProb = Nd4j.rand(100, 3);
        evalProb.diviColumnVector(evalProb.sum(1));
        roc3.eval(evalLabel, evalProb);

        evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 3), 0.5));
        evalProb = Nd4j.rand(100, 3);
        roc2.eval(evalLabel, evalProb);

        evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
        evalProb = Nd4j.rand(100, 1);
        roc.eval(evalLabel, evalProb);

        for (org.nd4j.evaluation.IEvaluation e : arr) {
            System.out.println(e.getClass());
            String json = e.toJson();
            String stats = e.stats();
            if (print) {
                System.out.println(json + "\n\n");
            }
            org.nd4j.evaluation.IEvaluation fromJson = BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
            assertEquals(e, fromJson);

            if (fromJson instanceof ROC) {
                //Shouldn't have probAndLabel, but should have stored AUC and AUPRC
                assertNull(((ROC) fromJson).getProbAndLabel());
                assertTrue(((ROC) fromJson).calculateAUC() > 0.0);
                assertTrue(((ROC) fromJson).calculateAUCPR() > 0.0);

                assertEquals(((ROC) e).getRocCurve(), ((ROC) fromJson).getRocCurve());
                assertEquals(((ROC) e).getPrecisionRecallCurve(), ((ROC) fromJson).getPrecisionRecallCurve());
            } else if (e instanceof ROCBinary) {
                org.nd4j.evaluation.classification.ROC[] rocs = ((ROCBinary) fromJson).getUnderlying();
                org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCBinary) e).getUnderlying();
                //                for(ROC r : rocs ){
                for (int i = 0; i < origRocs.length; i++) {
                    org.nd4j.evaluation.classification.ROC r = rocs[i];
                    org.nd4j.evaluation.classification.ROC origR = origRocs[i];
                    //Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
                    assertNull(r.getProbAndLabel());
                    assertEquals(origR.calculateAUC(), origR.calculateAUC(), 1e-6);
                    assertEquals(origR.calculateAUCPR(), origR.calculateAUCPR(), 1e-6);
                    assertEquals(origR.getRocCurve(), origR.getRocCurve());
                    assertEquals(origR.getPrecisionRecallCurve(), origR.getPrecisionRecallCurve());
                }

            } else if (e instanceof ROCMultiClass) {
                org.nd4j.evaluation.classification.ROC[] rocs = ((ROCMultiClass) fromJson).getUnderlying();
                org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCMultiClass) e).getUnderlying();
                for (int i = 0; i < origRocs.length; i++) {
                    org.nd4j.evaluation.classification.ROC r = rocs[i];
                    org.nd4j.evaluation.classification.ROC origR = origRocs[i];
                    //Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
                    assertNull(r.getProbAndLabel());
                    assertEquals(origR.calculateAUC(), origR.calculateAUC(), 1e-6);
                    assertEquals(origR.calculateAUCPR(), origR.calculateAUCPR(), 1e-6);
                    assertEquals(origR.getRocCurve(), origR.getRocCurve());
                    assertEquals(origR.getPrecisionRecallCurve(), origR.getPrecisionRecallCurve());
                }
            }
        }
    }

    @Test
    public void testJsonYamlCurves() {
        ROC roc = new ROC(0);

        INDArray evalLabel =
                        Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
        INDArray evalProb = Nd4j.rand(100, 1);
        roc.eval(evalLabel, evalProb);

        RocCurve c = roc.getRocCurve();
        PrecisionRecallCurve prc = roc.getPrecisionRecallCurve();

        String json1 = c.toJson();
        String json2 = prc.toJson();

        RocCurve c2 = RocCurve.fromJson(json1);
        PrecisionRecallCurve prc2 = PrecisionRecallCurve.fromJson(json2);

        assertEquals(c, c2);
        assertEquals(prc, prc2);

        //        System.out.println(json1);

        //Also test: histograms

        EvaluationCalibration ec = new EvaluationCalibration();

        evalLabel = Nd4j.create(10, 3);
        for (int i = 0; i < 10; i++) {
            evalLabel.putScalar(i, i % 3, 1.0);
        }
        evalProb = Nd4j.rand(10, 3);
        evalProb.diviColumnVector(evalProb.sum(1));
        ec.eval(evalLabel, evalProb);

        Histogram[] histograms = new Histogram[] {ec.getResidualPlotAllClasses(), ec.getResidualPlot(0),
                        ec.getResidualPlot(1), ec.getProbabilityHistogramAllClasses(), ec.getProbabilityHistogram(0),
                        ec.getProbabilityHistogram(1)};

        for (Histogram h : histograms) {
            String json = h.toJson();
            String yaml = h.toYaml();

            Histogram h2 = Histogram.fromJson(json);
            Histogram h3 = Histogram.fromYaml(yaml);

            assertEquals(h, h2);
            assertEquals(h2, h3);
        }

    }

    @Test
    public void testJsonWithCustomThreshold() {

        //Evaluation - binary threshold
        Evaluation e = new Evaluation(0.25);
        String json = e.toJson();
        String yaml = e.toYaml();

        Evaluation eFromJson = Evaluation.fromJson(json);
        Evaluation eFromYaml = Evaluation.fromYaml(yaml);

        assertEquals(0.25, eFromJson.getBinaryDecisionThreshold(), 1e-6);
        assertEquals(0.25, eFromYaml.getBinaryDecisionThreshold(), 1e-6);


        //Evaluation: custom cost array
        INDArray costArray = Nd4j.create(new double[] {1.0, 2.0, 3.0});
        Evaluation e2 = new Evaluation(costArray);

        json = e2.toJson();
        yaml = e2.toYaml();

        eFromJson = Evaluation.fromJson(json);
        eFromYaml = Evaluation.fromYaml(yaml);

        assertEquals(e2.getCostArray(), eFromJson.getCostArray());
        assertEquals(e2.getCostArray(), eFromYaml.getCostArray());



        //EvaluationBinary - per-output binary threshold
        INDArray threshold = Nd4j.create(new double[] {1.0, 0.5, 0.25});
        EvaluationBinary eb = new EvaluationBinary(threshold);

        json = eb.toJson();
        yaml = eb.toYaml();

        EvaluationBinary ebFromJson = EvaluationBinary.fromJson(json);
        EvaluationBinary ebFromYaml = EvaluationBinary.fromYaml(yaml);

        assertEquals(threshold, ebFromJson.getDecisionThreshold());
        assertEquals(threshold, ebFromYaml.getDecisionThreshold());

    }

}