/*
Copyright 2017-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance with the License. 
A copy of the License is located at

   http://aws.amazon.com/apache2.0/

or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
*/
package com.amazonaws.kinesisvideo.parser.utilities;

import com.amazonaws.kinesisvideo.parser.TestResourceUtil;
import com.amazonaws.kinesisvideo.parser.ebml.EBMLTypeInfo;
import com.amazonaws.kinesisvideo.parser.ebml.InputStreamParserByteSource;
import com.amazonaws.kinesisvideo.parser.ebml.MkvTypeInfos;
import com.amazonaws.kinesisvideo.parser.mkv.StreamingMkvReader;
import com.amazonaws.kinesisvideo.parser.mkv.visitors.CountVisitor;
import com.amazonaws.kinesisvideo.parser.mkv.visitors.ElementSizeAndOffsetVisitor;
import com.amazonaws.kinesisvideo.parser.mkv.MkvElement;
import com.amazonaws.kinesisvideo.parser.mkv.MkvElementVisitException;
import org.apache.commons.lang3.time.StopWatch;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;

import java.io.BufferedWriter;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

/**
 * Tests for OutputSegmentMerger
 */
public class OutputSegmentMergerTest {

    /**
     * This test merges the separate Mkv chunks generated by Kinesis Video GetMedia into one stream
     * as long as the chunks have the same EBML header and tracks.
     * It does a few things:
     * 1.Reads output_get_media.mkv that contains the output of get media call with 32 chunks as 32 Mkvstreaams.
     * 2.Merges it into one stream with 32 mkv clusters (fragments)
     * 3.It parses the merged stream to count the number of ebml headers, segments, clusters using
     * the {@link CountVisitor}.
     * Validates that the right number are present.
     * 4.Writes the merged output to a tmp file mergedoutput.mkv. mkvinfo can parse this successfully unliked the source
     * output_get_media.mkv.
     * 5.The test then parses the merged output again to print out the elements, their offsets in the merged mkv and
     * size of the element in bytes. It uses the {@link ElementSizeAndOffsetVisitor} to do this.
     *
     * @throws IOException
     * @throws InterruptedException
     */
    @Test
    public void mergeTracksAndEBML() throws IOException, InterruptedException, MkvElementVisitException {
        List<EBMLTypeInfo> typeInfosToMergeOn = new ArrayList<>();
        typeInfosToMergeOn.add(MkvTypeInfos.TRACKS);
        typeInfosToMergeOn.add(MkvTypeInfos.EBML);

        //Test that the merge works correctly.
        byte [] outputBytes = mergeTestInternal(typeInfosToMergeOn);

        //TODO: enable to write the merged output to a file.
       /* Path tmpFileName = Files.createTempFile("OutputSegmentMergerMergeTracksAndEBML", "mergedoutput.mkv");
        Files.write(tmpFileName, outputBytes,
                StandardOpenOption.WRITE,
                StandardOpenOption.CREATE);
        */
        //Write out the element id, offset and data sizes of the various elements.
        writeOutIdAndOffset(outputBytes);
    }

    @Test
    public void mergeWithTimeCodeBackwards()
            throws IOException, InterruptedException, MkvElementVisitException {
        //Read all the inputBytes so that we can compare with output bytes later.
        final byte [] inputBytes = TestResourceUtil.getTestInputByteArray("output_get_media.mkv");
        final InputStream in = getInputStreamForDoubleBytes(inputBytes);

        //Stream to receive the merged output.
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();

        //Do the actual merge.
        OutputSegmentMerger merger =
                OutputSegmentMerger.createDefault(outputStream);

        StreamingMkvReader mkvStreamReader =
                StreamingMkvReader.createDefault(new InputStreamParserByteSource(in));
        while(mkvStreamReader.mightHaveNext()) {
            Optional<MkvElement> mkvElement = mkvStreamReader.nextIfAvailable();
            if (mkvElement.isPresent()) {
                mkvElement.get().accept(merger);
            }
        }

        final byte[] outputBytes = outputStream.toByteArray();
        Assert.assertFalse(Arrays.equals(inputBytes, outputBytes));

        //Count different types of elements present in the merged stream.
        CountVisitor countVisitor = getCountVisitorResult(outputBytes);

        //Validate that there are two EBML headers and segment and tracks
        //but there are 64 clusters and tracks as expected.
        Assert.assertEquals(2, countVisitor.getCount(MkvTypeInfos.EBML));
        Assert.assertEquals(2, countVisitor.getCount(MkvTypeInfos.EBMLVERSION));
        Assert.assertEquals(2, countVisitor.getCount(MkvTypeInfos.SEGMENT));
        Assert.assertEquals(10, countVisitor.getCount(MkvTypeInfos.CLUSTER));
        Assert.assertEquals(10, countVisitor.getCount(MkvTypeInfos.TIMECODE));
        Assert.assertEquals(2, countVisitor.getCount(MkvTypeInfos.TRACKS));
        Assert.assertEquals(2, countVisitor.getCount(MkvTypeInfos.TRACKNUMBER));
        Assert.assertEquals(600, countVisitor.getCount(MkvTypeInfos.SIMPLEBLOCK));
        Assert.assertEquals(120, countVisitor.getCount(MkvTypeInfos.TAGNAME));
    }

    @Test
    public void stopWithTimeCodeBackwards()
            throws IOException, InterruptedException, MkvElementVisitException {
        //Read all the inputBytes so that we can compare with output bytes later.
        String fileName = "output-get-media-non-increasing-timecode.mkv";
        CountVisitor countVisitor = runMergerToStopAtFirstNonMatchingSegment(fileName);

        //Validate that there is only one EBML header and segment and tracks,
        //only 1 cluster and other elements as expected
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.EBML));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.EBMLVERSION));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.SEGMENT));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.CLUSTER));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.TIMECODE));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.TRACKS));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.TRACKNUMBER));
        Assert.assertEquals(30, countVisitor.getCount(MkvTypeInfos.SIMPLEBLOCK));
        Assert.assertEquals(59, countVisitor.getCount(MkvTypeInfos.TAGNAME));
    }

    @Test
    public void stopWithTimeCodeEqual()
            throws IOException, InterruptedException, MkvElementVisitException {
        //Read all the inputBytes so that we can compare with output bytes later.
        String fileName = "output-get-media-equal-timecode.mkv";
        CountVisitor countVisitor = runMergerToStopAtFirstNonMatchingSegment(fileName);

        //Validate that there is only one EBML header and segment and tracks,
        //only 1 cluster and other elements as expected
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.EBML));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.EBMLVERSION));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.SEGMENT));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.CLUSTER));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.TIMECODE));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.TRACKS));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.TRACKNUMBER));
        Assert.assertEquals(120, countVisitor.getCount(MkvTypeInfos.SIMPLEBLOCK));
        Assert.assertEquals(12, countVisitor.getCount(MkvTypeInfos.TAGNAME));
    }



    private CountVisitor runMergerToStopAtFirstNonMatchingSegment(String fileName)
            throws IOException, MkvElementVisitException {
        final byte [] inputBytes = TestResourceUtil.getTestInputByteArray(fileName);
        final InputStream in = getInputStreamForDoubleBytes(inputBytes);

        //Stream to receive the merged output.
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();

        //Do the actual merge.
        OutputSegmentMerger merger =
                OutputSegmentMerger.createToStopAtFirstNonMatchingSegment(outputStream);

        StreamingMkvReader mkvStreamReader =
                StreamingMkvReader.createDefault(new InputStreamParserByteSource(in));
        while(!merger.isDone() && mkvStreamReader.mightHaveNext()) {
            Optional<MkvElement> mkvElement = mkvStreamReader.nextIfAvailable();
            if (mkvElement.isPresent()) {
                mkvElement.get().accept(merger);
            }
        }

        Assert.assertTrue(merger.isDone());

        final byte[] outputBytes = outputStream.toByteArray();
        Assert.assertFalse(Arrays.equals(inputBytes, outputBytes));

        //Count different types of elements present in the merged stream.
        return getCountVisitorResult(outputBytes);
    }

    @Test
    public void mergeWithStopAfterFirstSegment()
            throws IOException, InterruptedException, MkvElementVisitException {

        //Read all the inputBytes so that we can compare with output bytes later.
        CountVisitor countVisitor = runMergerToStopAtFirstNonMatchingSegment("output_get_media.mkv");

        //Validate that there is only one EBML header and segment and tracks
        //but there are 32 clusters and tracks as expected.
        assertCountsAfterMerge(countVisitor);
    }

    private InputStream getInputStreamForDoubleBytes(byte[] inputBytes) throws IOException {
        ByteArrayOutputStream doubleStream = new ByteArrayOutputStream();
        doubleStream.write(inputBytes);
        doubleStream.write(inputBytes);

        //Reading again purely to show that the OutputSegmentMerger works even with streams
        //where all the data is not in memeory.
        return new ByteArrayInputStream(doubleStream.toByteArray());
    }

    private byte [] mergeTestInternal(List<EBMLTypeInfo> typeInfosToMergeOn)
            throws IOException, InterruptedException, MkvElementVisitException {
        //Read all the inputBytes so that we can compare with output bytes later.
        final byte [] inputBytes = TestResourceUtil.getTestInputByteArray("output_get_media.mkv");


        //Reading again purely to show that the OutputSegmentMerger works even with streams
        //where all the data is not in memeory.
        final InputStream in = TestResourceUtil.getTestInputStream("output_get_media.mkv");

        //Stream to receive the merged output.
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();

        //Do the actual merge.
        OutputSegmentMerger merger =
                OutputSegmentMerger.createDefault(outputStream);

        StreamingMkvReader mkvStreamReader =
                StreamingMkvReader.createDefault(new InputStreamParserByteSource(in));
        while(mkvStreamReader.mightHaveNext()) {
            Optional<MkvElement> mkvElement = mkvStreamReader.nextIfAvailable();
            if (mkvElement.isPresent()) {
                mkvElement.get().accept(merger);
            }
        }

        final byte []outputBytes = outputStream.toByteArray();
        Assert.assertFalse(Arrays.equals(inputBytes, outputBytes));

        //Count different types of elements present in the merged stream.
        CountVisitor countVisitor = getCountVisitorResult(outputBytes);

        //Validate that there is only one EBML header and segment and tracks
        //but there are 5 clusters and tracks as expected.
        assertCountsAfterMerge(countVisitor);

        return outputBytes;
    }

    private void assertCountsAfterMerge(CountVisitor countVisitor) {
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.EBML));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.EBMLVERSION));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.SEGMENT));
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.CLUSTER));
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.TIMECODE));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.TRACKS));
        Assert.assertEquals(1, countVisitor.getCount(MkvTypeInfos.TRACKNUMBER));
        Assert.assertEquals(300, countVisitor.getCount(MkvTypeInfos.SIMPLEBLOCK));
        Assert.assertEquals(60, countVisitor.getCount(MkvTypeInfos.TAGNAME));
    }

    private CountVisitor getCountVisitorResult(byte[] outputBytes) throws MkvElementVisitException {
        ByteArrayInputStream verifyStream = new ByteArrayInputStream(outputBytes);

        //List of elements to count.
        List<EBMLTypeInfo> typesToCount = new ArrayList<>();
        typesToCount.add(MkvTypeInfos.EBML);
        typesToCount.add(MkvTypeInfos.EBMLVERSION);
        typesToCount.add(MkvTypeInfos.SEGMENT);
        typesToCount.add(MkvTypeInfos.CLUSTER);
        typesToCount.add(MkvTypeInfos.TIMECODE);
        typesToCount.add(MkvTypeInfos.SIMPLEBLOCK);
        typesToCount.add(MkvTypeInfos.TRACKS);
        typesToCount.add(MkvTypeInfos.TRACKNUMBER);
        typesToCount.add(MkvTypeInfos.TAGNAME);

        //Create a visitor that counts the occurrences of the element.
        CountVisitor countVisitor = new CountVisitor(typesToCount);
        StreamingMkvReader verifyStreamReader =
                StreamingMkvReader.createDefault(new InputStreamParserByteSource(verifyStream));

        //Run the visitor over the stream.
        while(verifyStreamReader.mightHaveNext()) {
            Optional<MkvElement> mkvElement = verifyStreamReader.nextIfAvailable();
            if (mkvElement.isPresent()) {
                mkvElement.get().accept(countVisitor);
            }
        }

        Assert.assertTrue(countVisitor.doEndAndStartMasterElementsMatch());
        return countVisitor;
    }

    @Ignore
    @Test
    public void perfTest() throws IOException, MkvElementVisitException, InterruptedException {
        final byte [] inputBytes = TestResourceUtil.getTestInputByteArray("output_get_media.mkv");
        int numIterations = 1000;

        StopWatch timer = new StopWatch();
        timer.start();
        for (int i = 0; i < numIterations; i++) {
            try (ByteArrayInputStream in = new ByteArrayInputStream(inputBytes);
                    ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
                OutputSegmentMerger merger =
                        OutputSegmentMerger.createDefault(outputStream);

                StreamingMkvReader mkvStreamReader =
                        StreamingMkvReader.createWithMaxContentSize(new InputStreamParserByteSource(in), 32000);
                while(mkvStreamReader.mightHaveNext()) {
                    Optional<MkvElement> mkvElement = mkvStreamReader.nextIfAvailable();
                    if (mkvElement.isPresent()) {
                        mkvElement.get().accept(merger);
                    }
                }
            }
        }
        timer.stop();
        long totalTimeMillis = timer.getTime();
        double totalTimeSeconds = totalTimeMillis/(double )TimeUnit.SECONDS.toMillis(1);
        double mergeRate = (double )(inputBytes.length)*numIterations/(totalTimeSeconds*1024*1024);
        System.out.println("Total time "+totalTimeMillis+" ms "+" Merging rate "+mergeRate+" MB/s");
    }

    private static CountVisitor getCountVisitor() {
        return CountVisitor.create(MkvTypeInfos.CLUSTER, MkvTypeInfos.SEGMENT, MkvTypeInfos.SIMPLEBLOCK);
    }

    @Test
    public void basicTest() throws IOException, InterruptedException, MkvElementVisitException {
        final byte [] inputBytes = TestResourceUtil.getTestInputByteArray("output_get_media.mkv");

        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        ByteArrayInputStream in = new ByteArrayInputStream(inputBytes);
        OutputSegmentMerger merger =
                new OutputSegmentMerger(outputStream, new ArrayList<>(), getCountVisitor(), false);

        StreamingMkvReader mkvStreamReader =
                StreamingMkvReader.createDefault(new InputStreamParserByteSource(in));
        while(mkvStreamReader.mightHaveNext()) {
            Optional<MkvElement> mkvElement = mkvStreamReader.nextIfAvailable();
            if (mkvElement.isPresent()) {
                mkvElement.get().accept(merger);
            }
        }

        Assert.assertEquals(5, merger.getClustersCount());
        Assert.assertEquals(5, merger.getSegmentsCount());
        Assert.assertEquals(300, merger.getSimpleBlocksCount());

        final byte []outputBytes = outputStream.toByteArray();
        Assert.assertTrue(Arrays.equals(inputBytes, outputBytes));

        CountVisitor countVisitor = getCountVisitorResult(outputBytes);
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.EBML));
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.EBMLVERSION));
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.SEGMENT));
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.CLUSTER));
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.TIMECODE));
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.TRACKS));
        Assert.assertEquals(5, countVisitor.getCount(MkvTypeInfos.TRACKNUMBER));
        Assert.assertEquals(300, countVisitor.getCount(MkvTypeInfos.SIMPLEBLOCK));
        Assert.assertEquals(60, countVisitor.getCount(MkvTypeInfos.TAGNAME));
    }

    @Test
    public void mergeEBMLHeaders() throws IOException, InterruptedException, MkvElementVisitException {
        List<EBMLTypeInfo> typeInfosToMergeOn = new ArrayList<>();
        typeInfosToMergeOn.add(MkvTypeInfos.EBML);

        mergeTestInternal(typeInfosToMergeOn);
    }

    @Test
    public void mergeTracks() throws IOException, InterruptedException, MkvElementVisitException {
        List<EBMLTypeInfo> typeInfosToMergeOn = new ArrayList<>();
        typeInfosToMergeOn.add(MkvTypeInfos.TRACKS);

        mergeTestInternal(typeInfosToMergeOn);
    }



    private void writeOutIdAndOffset(byte[] outputBytes) throws IOException, MkvElementVisitException {
        ByteArrayInputStream offsetStream = new ByteArrayInputStream(outputBytes);
        StreamingMkvReader offsetReader =
                StreamingMkvReader.createDefault(new InputStreamParserByteSource(offsetStream));

        //Write the element name, offset and size to a file.
        Path tempFile = Files.createTempFile("Merger","offset");
        try (BufferedWriter writer = Files.newBufferedWriter(tempFile,
                StandardCharsets.US_ASCII,
                StandardOpenOption.WRITE,
                StandardOpenOption.CREATE)) {
            ElementSizeAndOffsetVisitor offsetVisitor = new ElementSizeAndOffsetVisitor(writer);

            while(offsetReader.mightHaveNext()) {
                Optional<MkvElement> mkvElement = offsetReader.nextIfAvailable();
                if (mkvElement.isPresent()) {
                    mkvElement.get().accept(offsetVisitor);
                }
            }
        } finally {
            Files.delete(tempFile);
        }
    }


}