/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.rng.sampling;

import java.util.Set;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.ArrayList;
import java.util.Collection;

import org.junit.Assert;
import org.junit.Test;

import org.apache.commons.math3.stat.inference.ChiSquareTest;

import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;

/**
 * Tests for {@link ListSampler}.
 */
public class ListSamplerTest {
    /** Generator with a fixed seed so the statistical tests are deterministic. */
    private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 6543432321L);
    /** Chi-square goodness-of-fit test used to check sampling uniformity. */
    private final ChiSquareTest chiSquareTest = new ChiSquareTest();

    /**
     * Draw 1000 samples of size 2 from the population {0, 1, 2, 3, 4} and check
     * that each of the 10 possible 2-subsets is drawn with equal frequency.
     */
    @Test
    public void testSample() {
        final String[][] c = {{"0", "1"}, {"0", "2"}, {"0", "3"}, {"0", "4"},
                              {"1", "2"}, {"1", "3"}, {"1", "4"},
                              {"2", "3"}, {"2", "4"},
                              {"3", "4"}};
        final long[] observed = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
        final double[] expected = {100, 100, 100, 100, 100, 100, 100, 100, 100, 100};

        final HashSet<String> cPop = new HashSet<String>(); // {0, 1, 2, 3, 4}.
        for (int i = 0; i < 5; i++) {
            cPop.add(Integer.toString(i));
        }

        final List<Set<String>> sets = new ArrayList<Set<String>>(); // 2-sets from 5.
        for (int i = 0; i < 10; i++) {
            final HashSet<String> hs = new HashSet<String>();
            hs.add(c[i][0]);
            hs.add(c[i][1]);
            sets.add(hs);
        }

        for (int i = 0; i < 1000; i++) {
            // A fresh copy each time so the sampler never sees a list modified by a previous call.
            observed[findSample(sets, ListSampler.sample(rng, new ArrayList<String>(cPop), 2))]++;
        }

        // Pass if we cannot reject null hypothesis that distributions are the same.
        Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
    }

    @Test
    public void testSampleWhole() {
        // Sample of size = size of collection must return the same collection.
        final List<String> list = new ArrayList<String>();
        list.add("one");

        final List<String> one = ListSampler.sample(rng, list, 1);
        Assert.assertEquals(1, one.size());
        Assert.assertTrue(one.contains("one"));
    }

    @Test(expected = IllegalArgumentException.class)
    public void testSamplePrecondition1() {
        // Must fail for sample size > collection size.
        final List<String> list = new ArrayList<String>();
        list.add("one");
        ListSampler.sample(rng, list, 2);
    }

    @Test(expected = IllegalArgumentException.class)
    public void testSamplePrecondition2() {
        // Must fail for empty collection.
        final List<String> list = new ArrayList<String>();
        ListSampler.sample(rng, list, 1);
    }

    @Test
    public void testShuffle() {
        final List<Integer> orig = new ArrayList<Integer>();
        for (int i = 0; i < 10; i++) {
            orig.add((i + 1) * rng.nextInt());
        }

        final List<Integer> arrayList = new ArrayList<Integer>(orig);

        ListSampler.shuffle(rng, arrayList);
        // Ensure that at least one entry has moved.
        Assert.assertTrue("ArrayList", compare(orig, arrayList, 0, orig.size(), false));

        final List<Integer> linkedList = new LinkedList<Integer>(orig);

        ListSampler.shuffle(rng, linkedList);
        // Ensure that at least one entry has moved.
        Assert.assertTrue("LinkedList", compare(orig, linkedList, 0, orig.size(), false));
    }

    @Test
    public void testShuffleTail() {
        final List<Integer> orig = new ArrayList<Integer>();
        for (int i = 0; i < 10; i++) {
            orig.add((i + 1) * rng.nextInt());
        }
        final List<Integer> list = new ArrayList<Integer>(orig);

        final int start = 4;
        ListSampler.shuffle(rng, list, start, false);

        // Ensure that all entries below index "start" did not move.
        Assert.assertTrue(compare(orig, list, 0, start, true));

        // Ensure that at least one entry has moved.
        Assert.assertTrue(compare(orig, list, start, orig.size(), false));
    }

    @Test
    public void testShuffleHead() {
        final List<Integer> orig = new ArrayList<Integer>();
        for (int i = 0; i < 10; i++) {
            orig.add((i + 1) * rng.nextInt());
        }
        final List<Integer> list = new ArrayList<Integer>(orig);

        final int start = 4;
        ListSampler.shuffle(rng, list, start, true);

        // Ensure that all entries above index "start" did not move.
        Assert.assertTrue(compare(orig, list, start + 1, orig.size(), true));

        // Ensure that at least one entry has moved.
        Assert.assertTrue(compare(orig, list, 0, start + 1, false));
    }

    /**
     * Test shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[])}.
     * The implementation may be different but the result is a Fisher-Yates shuffle so the
     * output order should match.
     */
    @Test
    public void testShuffleMatchesPermutationSamplerShuffle() {
        final List<Integer> orig = new ArrayList<Integer>();
        for (int i = 0; i < 10; i++) {
            orig.add((i + 1) * rng.nextInt());
        }

        assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig));
        assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig));
    }

    /**
     * Test shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[], int, boolean)}.
     * The implementation may be different but the result is a Fisher-Yates shuffle so the
     * output order should match.
     */
    @Test
    public void testShuffleMatchesPermutationSamplerShuffleDirectional() {
        final List<Integer> orig = new ArrayList<Integer>();
        for (int i = 0; i < 10; i++) {
            orig.add((i + 1) * rng.nextInt());
        }

        assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig), 4, true);
        assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig), 4, false);
        assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), 4, true);
        assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), 4, false);
    }

    /**
     * This test hits the edge case when a LinkedList is small enough that the algorithm
     * using a RandomAccess list is faster than the one with an iterator.
     */
    @Test
    public void testShuffleWithSmallLinkedList() {
        final int size = 3;
        final List<Integer> orig = new ArrayList<Integer>();
        for (int i = 0; i < size; i++) {
            orig.add((i + 1) * rng.nextInt());
        }

        // When the size is small the shuffle may produce the identity permutation, which
        // makes the comparison with the PermutationSampler weak. For 3 items there are only
        // 6 permutations, so there is a 1/6 chance of no change on a single run.
        // Each comparison uses a new random seed, so repeating 10 times reduces the chance
        // that every run is the identity to (1/6)^10, or 1 in 60,466,176.
        for (int i = 0; i < 10; i++) {
            assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), size - 1, true);
        }
    }

    //// Support methods.

    /**
     * Compare the entries of two lists over the index range {@code [start, end)}.
     *
     * If {@code same == true}, return {@code true} if all entries are
     * the same; if {@code same == false}, return {@code true} if at
     * least one entry is different.
     *
     * @param orig Reference list.
     * @param list List to compare against the reference.
     * @param start First index compared (inclusive).
     * @param end Last index compared (exclusive).
     * @param same Which condition to test (see above).
     * @return the result of the selected condition.
     */
    private <T> boolean compare(List<T> orig,
                                List<T> list,
                                int start,
                                int end,
                                boolean same) {
        for (int i = start; i < end; i++) {
            if (!orig.get(i).equals(list.get(i))) {
                // A difference fails "all same" and satisfies "at least one different".
                return !same;
            }
        }
        // No difference was found in the range.
        return same;
    }

    /**
     * Find the index of the set in {@code u} whose elements are exactly those of the sample.
     * The test fails if no set matches.
     *
     * @param u Candidate sets.
     * @param sampList Sample returned by the sampler.
     * @return the index in {@code u} of the matching set.
     */
    private <T extends Set<String>> int findSample(List<T> u,
                                                   Collection<String> sampList) {
        // The sample is the same for every candidate, so build its set view only once.
        final Set<String> sampSet = new HashSet<String>(sampList);
        for (int i = 0; i < u.size(); i++) {
            if (u.get(i).equals(sampSet)) {
                return i;
            }
        }

        // Report the whole sample. Indexing fixed positions would throw for samples smaller than 2.
        Assert.fail("Sample not found: " + sampList);
        return -1;
    }

    /**
     * Assert the shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[])}.
     *
     * @param list List whose entries will be shuffled (in-place).
     */
    private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list) {
        final int[] array = toIntArray(list);

        // Identical RNGs. The seed is random, so it is reported on failure to allow reproduction.
        final long seed = RandomSource.createLong();
        final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
        final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);

        ListSampler.shuffle(rng1, list);
        PermutationSampler.shuffle(rng2, array);

        final String msg = "Type=" + list.getClass().getSimpleName() + " seed=" + seed;
        assertListMatchesArray(msg, array, list);
    }

    /**
     * Assert the shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[], int, boolean)}.
     *
     * @param list List whose entries will be shuffled (in-place).
     * @param start Index at which shuffling begins.
     * @param towardHead Shuffling is performed for index positions between
     * {@code start} and either the end (if {@code false}) or the beginning
     * (if {@code true}) of the list.
     */
    private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list,
                                                                      int start,
                                                                      boolean towardHead) {
        final int[] array = toIntArray(list);

        // Identical RNGs. The seed is random, so it is reported on failure to allow reproduction.
        final long seed = RandomSource.createLong();
        final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
        final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);

        ListSampler.shuffle(rng1, list, start, towardHead);
        PermutationSampler.shuffle(rng2, array, start, towardHead);

        final String msg = String.format("Type=%s start=%d towardHead=%b seed=%d",
                list.getClass().getSimpleName(), start, towardHead, seed);
        assertListMatchesArray(msg, array, list);
    }

    /**
     * Copy the list into a new primitive array, preserving order.
     * An iterator is used so that traversal is linear for non-random-access lists.
     *
     * @param list Source list (must not contain {@code null}).
     * @return a new array holding the list entries.
     */
    private static int[] toIntArray(List<Integer> list) {
        final int[] array = new int[list.size()];
        final ListIterator<Integer> it = list.listIterator();
        for (int i = 0; i < array.length; i++) {
            array[i] = it.next();
        }
        return array;
    }

    /**
     * Assert the list has the same size as the array and holds the same entries in the same order.
     *
     * @param msg Failure message.
     * @param expected Expected entries.
     * @param list List to check.
     */
    private static void assertListMatchesArray(String msg, int[] expected, List<Integer> list) {
        Assert.assertEquals(msg, expected.length, list.size());
        final ListIterator<Integer> it = list.listIterator();
        for (int i = 0; i < expected.length; i++) {
            Assert.assertEquals(msg, expected[i], it.next().intValue());
        }
    }
}