/*
 * Copyright 2015, 2019 StreamEx contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package one.util.streamex;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Random;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntConsumer;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.ComparisonFailure;

import static one.util.streamex.Internals.TailSpliterator;
import static one.util.streamex.Internals.finished;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * @author Tagir Valeev
 */
public class TestHelpers {

    /**
     * Ways in which a stream supplier adjusts its base stream before handing it to a test.
     * The declaration order is the order in which the helpers of this class iterate
     * {@code Mode.values()}.
     */
    enum Mode {
        /** The base stream is switched to sequential. */
        NORMAL,
        /**
         * Sequential for {@code StreamSupplier}; {@code StreamExSupplier} instead rebuilds the
         * stream from the base stream's spliterator.
         */
        SPLITERATOR,
        /** The base stream is switched to parallel. */
        PARALLEL,
        /** The parallel base stream gets an empty collection appended. */
        APPEND,
        /** The parallel base stream gets an empty collection prepended. */
        PREPEND,
        /** A parallel stream over the base spliterator wrapped into {@code EmptyingSpliterator}. */
        RANDOM
    }

    /**
     * Produces a fresh {@link Stream} from the base supplier on every {@link #get()} call,
     * adjusted according to the configured {@link Mode}.
     *
     * @param <T> type of the stream elements
     */
    static class StreamSupplier<T> {
        final Mode mode;
        final Supplier<Stream<T>> base;

        /**
         * @param base source of the underlying streams; queried once per {@link #get()} call
         * @param mode how the underlying stream is adjusted
         */
        public StreamSupplier(Supplier<Stream<T>> base, Mode mode) {
            this.base = base;
            this.mode = mode;
        }

        /**
         * Obtains a new stream from the base supplier and adjusts it for the current mode.
         *
         * @return the adjusted stream
         * @throws InternalError if the mode is not handled by this method
         */
        public Stream<T> get() {
            Stream<T> stream = base.get();
            switch (mode) {
            case NORMAL:
            case SPLITERATOR: {
                return stream.sequential();
            }
            case PARALLEL: {
                return stream.parallel();
            }
            case APPEND: {
                // An empty ConcurrentLinkedQueue is used on purpose: Stream.empty() or
                // Arrays.asList() would be optimized out by append/prepend, which is undesired.
                StreamEx<T> parallel = StreamEx.of(stream.parallel());
                return parallel.append(new ConcurrentLinkedQueue<>());
            }
            case PREPEND: {
                StreamEx<T> parallel = StreamEx.of(stream.parallel());
                return parallel.prepend(new ConcurrentLinkedQueue<>());
            }
            case RANDOM: {
                Spliterator<T> emptying = new EmptyingSpliterator<>(stream.parallel().spliterator());
                return StreamEx.of(emptying).parallel();
            }
            default:
                throw new InternalError("Unsupported mode: " + mode);
            }
        }

        /**
         * @return the textual form of the mode, used by the helpers as a failure-message label
         */
        @Override
        public String toString() {
            return mode.toString();
        }
    }

    /**
     * A {@link StreamSupplier} whose {@link #get()} returns a {@link StreamEx}.
     *
     * @param <T> type of the stream elements
     */
    static class StreamExSupplier<T> extends StreamSupplier<T> {

        public StreamExSupplier(Supplier<Stream<T>> base, Mode mode) {
            super(base, mode);
        }

        /**
         * In {@link Mode#SPLITERATOR} mode the result is built directly from the spliterator of a
         * fresh base stream; in every other mode the stream produced by the superclass is wrapped.
         *
         * @return a new {@link StreamEx} for the current mode
         */
        @Override
        public StreamEx<T> get() {
            if (mode != Mode.SPLITERATOR) {
                return StreamEx.of(super.get());
            }
            Spliterator<T> spliterator = base.get().spliterator();
            return StreamEx.of(spliterator);
        }
    }

    /**
     * A {@link StreamSupplier} over map entries whose {@link #get()} returns an {@link EntryStream}.
     *
     * @param <K> type of the entry keys
     * @param <V> type of the entry values
     */
    static class EntryStreamSupplier<K, V> extends StreamSupplier<Map.Entry<K, V>> {

        public EntryStreamSupplier(Supplier<Stream<Map.Entry<K, V>>> base, Mode mode) {
            super(base, mode);
        }

        /**
         * @return the stream produced by the superclass for the current mode, wrapped into an
         *         {@link EntryStream}
         */
        @Override
        public EntryStream<K, V> get() {
            Stream<Map.Entry<K, V>> entries = super.get();
            return EntryStream.of(entries);
        }
    }

    /**
     * Creates one {@link StreamExSupplier} per {@link Mode}, in declaration order of the modes.
     *
     * @param base source of the underlying streams
     * @param <T> type of the stream elements
     * @return list of suppliers, one for every mode
     */
    static <T> List<StreamExSupplier<T>> streamEx(Supplier<Stream<T>> base) {
        Function<Mode, StreamExSupplier<T>> toSupplier = mode -> new StreamExSupplier<>(base, mode);
        return StreamEx.of(Mode.values()).map(toSupplier).toList();
    }

    /**
     * Runs the consumer once, passing it a {@link Random} initialized with a seed drawn from
     * {@link ThreadLocalRandom}. The seed is added to the message of every failed assertion.
     *
     * @param cons consumer to run
     */
    static void withRandom(Consumer<Random> cons) {
        withRandom(ThreadLocalRandom.current().nextLong(), cons);
    }

    /**
     * Runs the consumer once, passing it a {@link Random} initialized with the given seed.
     * Every failure raised by the consumer gets the text {@code Using new Random(<seed>)}
     * prepended to its message, so the run can be reproduced.
     *
     * @param seed random seed to use
     * @param cons consumer to run
     */
    static void withRandom(long seed, Consumer<Random> cons) {
        String label = "Using new Random(" + seed + ")";
        Random rng = new Random(seed);
        withMessage(label, () -> cons.accept(rng));
    }
    
    /**
     * Runs the runnable and rethrows anything it throws with {@code message + ": "} prepended
     * to the failure message.
     * <ul>
     * <li>A {@link ComparisonFailure} is re-created with the same expected/actual values and
     * the original stack trace.</li>
     * <li>Any other {@link AssertionError} is re-created with the original cause and stack
     * trace.</li>
     * <li>Every other {@link RuntimeException} or {@link Error} is wrapped into a new
     * {@link RuntimeException} which keeps the original as its cause.</li>
     * </ul>
     *
     * @param message message to prepend
     * @param r runnable to run
     */
    private static void withMessage(String message, Runnable r) {
        try {
            r.run();
        } catch (ComparisonFailure failure) {
            String text = message + ": " + failure.getMessage();
            ComparisonFailure prefixed = new ComparisonFailure(text, failure.getExpected(), failure.getActual());
            prefixed.setStackTrace(failure.getStackTrace());
            throw prefixed;
        } catch (AssertionError failure) {
            String text = message + ": " + failure.getMessage();
            AssertionError prefixed = new AssertionError(text, failure.getCause());
            prefixed.setStackTrace(failure.getStackTrace());
            throw prefixed;
        } catch (RuntimeException | Error failure) {
            String text = message + ": " + failure.getMessage();
            throw new RuntimeException(text, failure);
        }
    }

    /**
     * Invokes the consumer with the numbers {@code 1..times} in ascending order. A failure in
     * iteration {@code i} gets {@code "#i"} prepended to its message.
     *
     * @param times number of iterations
     * @param consumer action receiving the 1-based iteration number
     */
    static void repeat(int times, IntConsumer consumer) {
        for (int iteration = 1; iteration <= times; iteration++) {
            final int current = iteration;
            String label = "#" + current;
            withMessage(label, () -> consumer.accept(current));
        }
    }

    /**
     * Feeds the consumer with a {@link StreamExSupplier} for every {@link Mode}, in declaration
     * order. A failure raised by the consumer gets the supplier's textual form (the mode)
     * prepended to its message.
     *
     * @param base source of the underlying streams
     * @param consumer test body to run for every mode
     * @param <T> type of the stream elements
     */
    static <T> void streamEx(Supplier<Stream<T>> base, Consumer<StreamExSupplier<T>> consumer) {
        for (Mode mode : Mode.values()) {
            StreamExSupplier<T> supplier = new StreamExSupplier<>(base, mode);
            String label = supplier.toString();
            withMessage(label, () -> consumer.accept(supplier));
        }
    }

    /**
     * Runs the consumer in every {@link Mode} over an empty base stream.
     *
     * @param clazz not read by this method; presumably present only so that callers can pin
     *        the element type {@code T}
     * @param consumer test body to run for every mode
     * @param <T> type of the stream elements
     */
    static <T> void emptyStreamEx(Class<T> clazz, Consumer<StreamExSupplier<T>> consumer) {
        Supplier<Stream<T>> noElements = Stream::empty;
        streamEx(noElements, consumer);
    }

    /**
     * Feeds the consumer with an {@link EntryStreamSupplier} for every {@link Mode}, in
     * declaration order. A failure raised by the consumer gets the supplier's textual form
     * (the mode) prepended to its message.
     *
     * @param base source of the underlying entry streams
     * @param consumer test body to run for every mode
     * @param <K> type of the entry keys
     * @param <V> type of the entry values
     */
    static <K, V> void entryStream(Supplier<Stream<Map.Entry<K, V>>> base, Consumer<EntryStreamSupplier<K, V>> consumer) {
        for (Mode mode : Mode.values()) {
            EntryStreamSupplier<K, V> supplier = new EntryStreamSupplier<>(base, mode);
            String label = supplier.toString();
            withMessage(label, () -> consumer.accept(supplier));
        }
    }

    /**
     * Spliterator wrapper which randomly introduces empty spliterators when it is split.
     * Traversal, size estimate, characteristics and comparator are all delegated to the
     * current source.
     *
     * @author Tagir Valeev
     *
     * @param <T> type of the elements
     */
    private static class EmptyingSpliterator<T> implements Spliterator<T> {
        private Spliterator<T> source;

        /**
         * @param source spliterator to wrap; must not be null
         */
        public EmptyingSpliterator(Spliterator<T> source) {
            this.source = Objects.requireNonNull(source);
        }

        /**
         * Picks one of three outcomes with equal probability:
         * <ol>
         * <li>an empty spliterator is returned as the prefix and this spliterator is left
         * untouched;</li>
         * <li>the whole current source is returned as the prefix and this spliterator
         * becomes empty;</li>
         * <li>the source itself is split and its prefix (if any) is wrapped into a new
         * {@code EmptyingSpliterator}.</li>
         * </ol>
         */
        @Override
        public Spliterator<T> trySplit() {
            Spliterator<T> current = this.source;
            int choice = ThreadLocalRandom.current().nextInt(3);
            if (choice == 0) {
                return Spliterators.emptySpliterator();
            }
            if (choice == 1) {
                this.source = Spliterators.emptySpliterator();
                return current;
            }
            Spliterator<T> prefix = current.trySplit();
            if (prefix == null) {
                return null;
            }
            return new EmptyingSpliterator<>(prefix);
        }

        @Override
        public boolean tryAdvance(Consumer<? super T> action) {
            return source.tryAdvance(action);
        }

        @Override
        public void forEachRemaining(Consumer<? super T> action) {
            source.forEachRemaining(action);
        }

        @Override
        public long estimateSize() {
            return source.estimateSize();
        }

        @Override
        public int characteristics() {
            return source.characteristics();
        }

        @Override
        public Comparator<? super T> getComparator() {
            return source.getComparator();
        }
    }

    /**
     * Checks the result of the collector on an empty stream. If {@code finished(collector)}
     * is non-null the short-circuiting check is used (expecting zero consumed elements),
     * otherwise the ordinary collector check is used.
     *
     * @param message prefix for assertion messages
     * @param expected expected collection result
     * @param collector collector under test
     */
    static <T, R> void checkCollectorEmpty(String message, R expected, Collector<T, ?, R> collector) {
        boolean shortCircuiting = finished(collector) != null;
        if (!shortCircuiting) {
            checkCollector(message, expected, Stream::empty, collector);
            return;
        }
        checkShortCircuitCollector(message, expected, 0, Stream::empty, collector);
    }

    /**
     * Same as the six-argument overload with the identity-finisher check enabled.
     *
     * @param message prefix for assertion messages
     * @param expected expected collection result
     * @param expectedConsumedElements number of elements a non-parallel stream is expected to
     *        pass to the collector
     * @param base source of the underlying streams
     * @param collector collector under test
     */
    static <T, TT extends T, R> void checkShortCircuitCollector(String message, R expected,
            int expectedConsumedElements, Supplier<Stream<TT>> base, Collector<T, ?, R> collector) {
        final boolean skipIdentity = false;
        checkShortCircuitCollector(message, expected, expectedConsumedElements, base, collector, skipIdentity);
    }

    /**
     * Checks a collector for which {@code finished(collector)} is non-null, in every
     * {@link Mode}:
     * <ul>
     * <li>the collection result must equal {@code expected};</li>
     * <li>when a stream obtained from the supplier is not parallel, the number of elements
     * observed by a {@code peek} stage must equal {@code expectedConsumedElements};</li>
     * <li>unless {@code skipIdentity} is set, the collector wrapped with
     * {@code collectingAndThen(collector, identity())} must produce the same result.</li>
     * </ul>
     *
     * @param message prefix for assertion messages
     * @param expected expected collection result
     * @param expectedConsumedElements expected number of consumed elements for non-parallel streams
     * @param base source of the underlying streams
     * @param collector collector under test
     * @param skipIdentity whether to skip the check with the identity finisher
     */
    static <T, TT extends T, R> void checkShortCircuitCollector(String message, R expected,
            int expectedConsumedElements, Supplier<Stream<TT>> base, Collector<T, ?, R> collector, boolean skipIdentity) {
        assertNotNull(message, finished(collector));
        Collector<T, ?, R> identityWrapped = Collectors.collectingAndThen(collector, Function.identity());
        for (StreamExSupplier<TT> supplier : streamEx(base)) {
            String label = message + ": " + supplier;
            AtomicInteger consumed = new AtomicInteger();
            R direct = supplier.get().peek(t -> consumed.incrementAndGet()).collect(collector);
            assertEquals(label, expected, direct);
            // The consumed-element count is only checked for non-parallel streams.
            if (!supplier.get().isParallel()) {
                assertEquals(label + ": consumed: ", expectedConsumedElements, consumed.get());
            }
            if (!skipIdentity) {
                R viaIdentity = supplier.get().collect(identityWrapped);
                assertEquals(label, expected, viaIdentity);
            }
        }
    }

    /**
     * Checks that the collector produces {@code expected} in every {@link Mode}. The
     * collector must be one for which {@code finished(collector)} is null; use
     * {@code checkShortCircuitCollector} for a CancellableCollector.
     *
     * @param message prefix for assertion messages
     * @param expected expected collection result
     * @param base source of the underlying streams
     * @param collector collector under test
     */
    static <T, TT extends T, R> void checkCollector(String message, R expected, Supplier<Stream<TT>> base,
            Collector<T, ?, R> collector) {
        assertNull(message, finished(collector));
        for (StreamExSupplier<TT> supplier : streamEx(base)) {
            String label = message + ": " + supplier;
            R actual = supplier.get().collect(collector);
            assertEquals(label, expected, actual);
        }
    }

    /**
     * Checks the spliterators produced by the supplier against the elements which one such
     * spliterator yields through {@code forEachRemaining}.
     *
     * @param msg prefix for assertion messages
     * @param supplier source of fresh spliterators
     * @param <T> type of the elements
     */
    static <T> void checkSpliterator(String msg, Supplier<Spliterator<T>> supplier) {
        List<T> reference = new ArrayList<>();
        Spliterator<T> first = supplier.get();
        first.forEachRemaining(reference::add);
        checkSpliterator(msg, reference, supplier);
    }

    /**
     * Tests whether spliterators produced by given supplier produce the
     * expected result under various splittings.
     * <p>
     * This test is single-threaded. Its behavior is randomized, but the random
     * seed is added to the failure message (see {@code withRandom}), so the
     * results could be reproduced.
     *
     * @param msg message used in assertion failures
     * @param expected elements every spliterator is expected to produce, in encounter order
     * @param supplier source of fresh spliterators; queried many times
     * @param <T> type of the elements
     */
    static <T> void checkSpliterator(String msg, List<T> expected, Supplier<Spliterator<T>> supplier) {
        List<T> seq = new ArrayList<>();

        // Test characteristics: a SORTED spliterator must be able to return its comparator,
        // and the size estimate must not be negative
        Spliterator<T> forCharacteristics = supplier.get();
        if (forCharacteristics.hasCharacteristics(Spliterator.SORTED)) {
            forCharacteristics.getComparator(); // must not fail
        }
        assertTrue(forCharacteristics.estimateSize() >= 0);

        // Test forEachRemaining: a single bulk traversal must deliver everything;
        // afterwards neither tryAdvance nor forEachRemaining may deliver anything more
        Spliterator<T> sequential = supplier.get();
        sequential.forEachRemaining(seq::add);
        assertFalse(msg, sequential.tryAdvance(t -> fail(msg + ": Advance called with " + t)));
        sequential.forEachRemaining(t -> fail(msg + ": Advance called with " + t));
        assertEquals(msg, expected, seq);

        // Test tryAdvance: element-by-element traversal; the returned flag must agree
        // with whether the consumer was actually invoked
        seq.clear();
        sequential = supplier.get();
        while (true) {
            AtomicBoolean called = new AtomicBoolean();
            boolean res = sequential.tryAdvance(t -> {
                seq.add(t);
                called.set(true);
            });
            if (res != called.get()) {
                fail(msg
                    + (res ? ": Consumer not called, but spliterator returned true"
                            : ": Consumer called, but spliterator returned false"));
            }
            if (!res)
                break;
        }
        // An exhausted spliterator must keep returning false
        assertFalse(msg, sequential.tryAdvance(t -> fail(msg + ": Advance called with " + t)));
        assertEquals(msg, expected, seq);
        
        // Test TailSpliterator: only when the supplied spliterator is a TailSpliterator,
        // traverse via forEachWithTail, then via tryAdvanceWithTail until it returns null
        if (sequential instanceof TailSpliterator) {
            seq.clear();
            TailSpliterator.forEachWithTail(supplier.get(), seq::add);
            assertEquals(msg, expected, seq);
            seq.clear();
            sequential = supplier.get();
            while (sequential != null) {
                sequential = TailSpliterator.tryAdvanceWithTail(sequential, seq::add);
            }
        }
        // Verifies the tryAdvanceWithTail traversal above; when the branch was skipped,
        // seq still holds the tryAdvance result, so this just repeats the previous check
        assertEquals(msg, expected, seq);
        
        // Test advance+remaining: consume i elements (i = 1, 2, 3, bounded by
        // expected.size() - 1) one by one, then the rest in bulk
        for (int i = 1; i < Math.min(4, expected.size() - 1); i++) {
            seq.clear();
            sequential = supplier.get();
            for (int j = 0; j < i; j++) assertTrue(msg, sequential.tryAdvance(seq::add));
            sequential.forEachRemaining(seq::add);
            assertEquals(msg, expected, seq);
        }

        // Test trySplit
        withRandom(r -> {
            // Phase 1: split randomly, then traverse the parts in bulk in random order
            repeat(500, n -> {
                Spliterator<T> spliterator = supplier.get();
                List<Spliterator<T>> spliterators = new ArrayList<>();
                spliterators.add(spliterator);
                // 2..11 split attempts, each on a randomly chosen part
                int p = r.nextInt(10) + 2;
                for (int i = 0; i < p; i++) {
                    int idx = r.nextInt(spliterators.size());
                    Spliterator<T> split = spliterators.get(idx).trySplit();
                    // the split-off part is inserted right before the part it came from,
                    // so the list stays in encounter order
                    if (split != null)
                        spliterators.add(idx, split);
                }
                List<Integer> order = IntStreamEx.ofIndices(spliterators).boxed().toList();
                Collections.shuffle(order, r);
                // Traverse the parts in shuffled order, then restore the original order
                // by index before concatenating the collected elements
                List<T> list = StreamEx.of(order).mapToEntry(idx -> {
                    Spliterator<T> s = spliterators.get(idx);
                    Stream.Builder<T> builder = Stream.builder();
                    s.forEachRemaining(builder);
                    // an exhausted part must not deliver anything more
                    assertFalse(msg, s.tryAdvance(t -> fail(msg + ": Advance called with " + t)));
                    s.forEachRemaining(t -> fail(msg + ": Advance called with " + t));
                    return builder.build();
                }).sortedBy(Entry::getKey).values().flatMap(Function.identity()).toList();
                assertEquals(msg, expected, list);
            });
            // Phase 2: split randomly, then advance the parts one element at a time,
            // interleaving them randomly
            repeat(500, n -> {
                Spliterator<T> spliterator = supplier.get();
                List<Spliterator<T>> spliterators = new ArrayList<>();
                spliterators.add(spliterator);
                // 2..31 split attempts, each on a randomly chosen part
                int p = r.nextInt(30) + 2;
                for (int i = 0; i < p; i++) {
                    int idx = r.nextInt(spliterators.size());
                    Spliterator<T> split = spliterators.get(idx).trySplit();
                    if (split != null)
                        spliterators.add(idx, split);
                }
                // one result list per part, at the same index as the part
                List<List<T>> results = StreamEx.<List<T>>generate(ArrayList::new).limit(spliterators.size())
                        .toList();
                // count = number of parts which are not exhausted yet
                int count = spliterators.size();
                while (count > 0) {
                    int i;
                    // pick a random part which is not exhausted (exhausted ones are nulled out)
                    do {
                        i = r.nextInt(spliterators.size());
                        spliterator = spliterators.get(i);
                    } while (spliterator == null);
                    if (!spliterator.tryAdvance(results.get(i)::add)) {
                        spliterators.set(i, null);
                        count--;
                    }
                }
                // concatenating the per-part results in index order must give the expected list
                List<T> list = StreamEx.of(results).flatMap(List::stream).toList();
                assertEquals(msg, expected, list);
            });
        });
    }

    /**
     * Advances the spliterator by one element and asserts that the advance succeeded, that
     * the consumer was invoked, and that the delivered element equals {@code element}.
     *
     * @param spliterator spliterator to advance
     * @param element element expected to be delivered
     * @param <T> type of the elements
     */
    static <T> void consumeElement(Spliterator<T> spliterator, T element) {
        AtomicBoolean delivered = new AtomicBoolean();
        boolean advanced = spliterator.tryAdvance(actual -> {
            assertEquals(element, actual);
            delivered.set(true);
        });
        assertTrue(advanced);
        assertTrue(delivered.get());
    }

    /**
     * Runs the runnable and verifies that it throws an {@link IllegalStateException} whose
     * message reports a duplicate entry for the given key.
     * <p>
     * The two values may appear in the message in either order, and the whole message may
     * optionally be prefixed with {@code "java.lang.IllegalStateException: "}.
     * A missing ({@code null}) exception message is reported as a wrong message rather than
     * causing a {@link NullPointerException}, and the offending exception is attached as the
     * cause of the resulting {@link AssertionError}.
     *
     * @param r code expected to throw
     * @param key key expected in the message
     * @param value1 one of the two conflicting values, in textual form
     * @param value2 the other conflicting value, in textual form
     * @throws AssertionError if nothing is thrown or the message does not match
     */
    static void checkIllegalStateException(Runnable r, String key, String value1, String value2) {
        try {
            r.run();
        } catch (IllegalStateException ex) {
            String direct = "Duplicate entry for key '" + key + "' (attempt to merge values '" + value1
                + "' and '" + value2 + "')";
            String swapped = "Duplicate entry for key '" + key + "' (attempt to merge values '" + value2
                + "' and '" + value1 + "')";
            String prefix = "java.lang.IllegalStateException: ";
            String exmsg = ex.getMessage();
            // The expected text is always the receiver of equals(), so a null message is safe.
            boolean matches = direct.equals(exmsg) || swapped.equals(exmsg)
                || (prefix + direct).equals(exmsg) || (prefix + swapped).equals(exmsg);
            if (!matches) {
                throw new AssertionError("wrong exception message: " + exmsg, ex);
            }
            return;
        }
        throw new AssertionError("no exception");
    }

    /**
     * A piece of test code which is allowed to throw anything, including checked exceptions.
     */
    @FunctionalInterface
    interface Statement {
        void evaluate() throws Throwable;
    }

    /**
     * Asserts that the statement throws an exception of the expected type (or a subtype),
     * accepting any exception message.
     *
     * @param expected expected exception type
     * @param statement code expected to throw
     * @throws AssertionError if nothing is thrown or the thrown type does not match
     */
    static void assertThrows(Class<? extends Throwable> expected, Statement statement) {
        assertThrows(expected, msg -> true, statement);
    }

    /**
     * Asserts that the statement throws an exception of the expected type (or a subtype)
     * whose message is accepted by the predicate.
     * <p>
     * Every {@link AssertionError} raised because of a thrown-but-unsuitable exception
     * carries that exception as its cause, so its stack trace stays available.
     *
     * @param expected expected exception type
     * @param checkExceptionAction predicate applied to the message of the thrown exception;
     *        the message passed to it may be {@code null}
     * @param statement code expected to throw
     * @throws AssertionError if nothing is thrown, the thrown type does not match, or the
     *         predicate rejects the message
     */
    static void assertThrows(Class<? extends Throwable> expected, Predicate<? super String> checkExceptionAction,
                             Statement statement) {
        try {
            statement.evaluate();
        } catch (Throwable e) {
            if (!expected.isInstance(e)) {
                throw new AssertionError("Unexpected exception, " +
                        "expected<" + expected.getName() + "> " +
                        "but was<" + e.getClass().getName() + ">", e);
            }
            if (!checkExceptionAction.test(e.getMessage())) {
                // Keep the rejected exception as the cause instead of discarding it.
                throw new AssertionError("Unexpected exception<" + e.getMessage() + ">", e);
            }
            return;
        }
        throw new AssertionError("Expected exception: " + expected.getName());
    }

    /**
     * Creates a spliterator which never delivers an element but reports the given value from
     * {@code getExactSizeIfKnown()}. It is constructed with a size estimate of 0 and the
     * {@link Spliterator#SIZED} characteristic.
     *
     * @param exactSize value to report as the exact size
     * @param <T> type of the (absent) elements
     * @return the new spliterator
     */
    static <T> Spliterator<T> emptySpliteratorWithExactSize(long exactSize) {
        final class FixedSizeEmptySpliterator extends Spliterators.AbstractSpliterator<T> {
            FixedSizeEmptySpliterator() {
                super(0, Spliterator.SIZED);
            }

            @Override
            public boolean tryAdvance(Consumer<? super T> ignored) {
                return false;
            }

            @Override
            public long getExactSizeIfKnown() {
                return exactSize;
            }
        }
        return new FixedSizeEmptySpliterator();
    }
}