/*
 * Copyright 2015, 2019 StreamEx contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package one.util.streamex;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.StreamSupport;

/**
 * A {@link Spliterator} over the Cartesian product of a list of source collections. It enumerates every
 * combination that picks one element from each collection, in source order, with the last collection
 * varying fastest (odometer style). Subclasses define what is emitted for a combination through
 * {@link #accumulate(int, Object)} and how a split-off prefix is built through
 * {@link #doSplit(long, Spliterator[], Collection[])}.
 *
 * <p>
 * The source is expected to contain at least one collection. With an empty source, traversal and
 * splitting index into empty arrays and fail with {@link ArrayIndexOutOfBoundsException}.
 *
 * @author Tagir Valeev
 * @param <T> the type of the elements of the source collections
 * @param <A> the type of the value emitted for each combination
 */
/* package */abstract class CrossSpliterator<T, A> implements Spliterator<A> {
    /**
     * Estimated number of remaining combinations. {@code Long.MAX_VALUE} means the count is unknown or
     * does not fit into a {@code long}.
     */
    long est;
    /**
     * Index of the collection currently used for splitting. Positions below it have been fixed to a
     * single element by {@link #trySplit()}.
     */
    int splitPos;
    /**
     * Current spliterator for every position. {@code null} means the position must be (re)started from
     * the beginning of its collection.
     */
    final Spliterator<T>[] spliterators;
    /** The source collections, one per position. */
    final Collection<T>[] collections;

    /**
     * Creates a spliterator over the product of the given collections. The initial size estimate is the
     * product of the collection sizes, saturated to {@code Long.MAX_VALUE} on overflow.
     *
     * @param source the collections whose Cartesian product is enumerated
     */
    @SuppressWarnings("unchecked")
    CrossSpliterator(Collection<? extends Collection<T>> source) {
        this.splitPos = 0;
        long est = 1;
        try {
            for (Collection<T> c : source) {
                long size = c.size();
                est = StrictMath.multiplyExact(est, size);
            }
        } catch (ArithmeticException e) {
            // The product does not fit into a long: treat the size as unknown.
            est = Long.MAX_VALUE;
        }
        this.est = est;
        this.collections = source.toArray(new Collection[0]);
        this.spliterators = new Spliterator[collections.length];
    }

    /**
     * Creates the spliterator returned by {@link #trySplit()} for the prefix part.
     *
     * @param prefixEst size estimate of the prefix ({@code Long.MAX_VALUE} if unknown)
     * @param prefixSpliterators per-position spliterators for the prefix
     * @param prefixCollections per-position source collections for the prefix
     * @return the prefix spliterator, which must carry a copy of this spliterator's accumulated state
     */
    abstract Spliterator<A> doSplit(long prefixEst, Spliterator<T>[] prefixSpliterators,
            Collection<T>[] prefixCollections);

    /**
     * Records {@code t} as the element currently chosen at position {@code pos}.
     *
     * @param pos the position (collection index)
     * @param t the chosen element
     */
    abstract void accumulate(int pos, T t);

    /**
     * Creates a spliterator from explicit state. It is used when building split-off prefixes.
     */
    CrossSpliterator(long est, int splitPos, Spliterator<T>[] spliterators, Collection<T>[] collections) {
        this.est = est;
        this.splitPos = splitPos;
        this.spliterators = spliterators;
        this.collections = collections;
    }

    /**
     * Emits, for every combination, the left fold of its elements with {@code accumulator}, starting
     * from {@code identity}.
     */
    static final class Reducing<T, A> extends CrossSpliterator<T, A> {
        /**
         * Prefix folds of the current combination. {@code elements[0]} is the identity and
         * {@code elements[k + 1]} is the fold up to and including position {@code k}. The last slot
         * is therefore the value to emit. The array is {@code null} once the traversal is exhausted.
         */
        private A[] elements;
        private final BiFunction<A, ? super T, A> accumulator;

        @SuppressWarnings("unchecked")
        Reducing(Collection<? extends Collection<T>> source, A identity, BiFunction<A, ? super T, A> accumulator) {
            super(source);
            this.accumulator = accumulator;
            this.elements = (A[]) new Object[collections.length + 1];
            this.elements[0] = identity;
        }

        private Reducing(long est, int splitPos, BiFunction<A, ? super T, A> accumulator,
                Spliterator<T>[] spliterators, Collection<T>[] collections, A[] elements) {
            super(est, splitPos, spliterators, collections);
            this.accumulator = accumulator;
            this.elements = elements;
        }

        @Override
        public boolean tryAdvance(Consumer<? super A> action) {
            if (elements == null)
                return false;
            // A known, non-zero estimate shrinks by one per emitted combination.
            if (est < Long.MAX_VALUE && est > 0)
                est--;
            int l = collections.length;
            if (advance(l - 1)) {
                action.accept(elements[l]);
                return true;
            }
            elements = null;
            est = 0;
            return false;
        }

        @Override
        public void forEachRemaining(Consumer<? super A> action) {
            if (elements == null)
                return;
            int l = collections.length;
            A[] e = elements;
            while (advance(l - 1)) {
                action.accept(e[l]);
            }
            elements = null;
            est = 0;
        }

        @Override
        Spliterator<A> doSplit(long prefixEst, Spliterator<T>[] prefixSpliterators, Collection<T>[] prefixCollections) {
            // The prefix gets its own copy of the partial folds so both halves can progress independently.
            return new Reducing<>(prefixEst, splitPos, accumulator, prefixSpliterators, prefixCollections, elements
                    .clone());
        }

        @Override
        void accumulate(int pos, T t) {
            elements[pos + 1] = accumulator.apply(elements[pos], t);
        }
    }

    /**
     * Emits every combination as a fresh {@link ArrayList} holding one element per source collection.
     */
    static final class ToList<T> extends CrossSpliterator<T, List<T>> {
        /**
         * Fixed-size list holding the element currently chosen at each position. It is {@code null}
         * once the traversal is exhausted.
         */
        private List<T> elements;

        @SuppressWarnings("unchecked")
        ToList(Collection<? extends Collection<T>> source) {
            super(source);
            this.elements = (List<T>) Arrays.asList(new Object[collections.length]);
        }

        private ToList(long est, int splitPos, Spliterator<T>[] spliterators, Collection<T>[] collections,
                List<T> elements) {
            super(est, splitPos, spliterators, collections);
            this.elements = elements;
        }

        @Override
        public boolean tryAdvance(Consumer<? super List<T>> action) {
            if (elements == null)
                return false;
            // A known, non-zero estimate shrinks by one per emitted combination.
            if (est < Long.MAX_VALUE && est > 0)
                est--;
            if (advance(collections.length - 1)) {
                // Copy, since the backing list is mutated by later advances.
                action.accept(new ArrayList<>(elements));
                return true;
            }
            elements = null;
            est = 0;
            return false;
        }

        @Override
        public void forEachRemaining(Consumer<? super List<T>> action) {
            if (elements == null)
                return;
            List<T> e = elements;
            int l = collections.length - 1;
            while (advance(l)) {
                action.accept(new ArrayList<>(e));
            }
            elements = null;
            est = 0;
        }

        @Override
        Spliterator<List<T>> doSplit(long prefixEst, Spliterator<T>[] prefixSpliterators,
                Collection<T>[] prefixCollections) {
            // The prefix gets its own copy of the current combination.
            @SuppressWarnings("unchecked")
            List<T> prefixElements = (List<T>) Arrays.asList(elements.toArray());
            return new ToList<>(prefixEst, splitPos, prefixSpliterators, prefixCollections, prefixElements);
        }

        @Override
        void accumulate(int pos, T t) {
            elements.set(pos, t);
        }
    }

    /**
     * Moves position {@code i} to its next element, odometer style.
     *
     * <p>
     * When the spliterator at {@code i} is exhausted, position {@code i - 1} is advanced and position
     * {@code i} restarts from the beginning of its collection. A position whose spliterator is
     * {@code null} is started only after the positions before it have been advanced.
     *
     * @param i the position to advance
     * @return {@code true} if a new element was accumulated at position {@code i}; {@code false} if no
     *         combinations remain
     */
    boolean advance(int i) {
        if (spliterators[i] == null) {
            if (i > 0 && collections[i - 1] != null && !advance(i - 1))
                return false;
            spliterators[i] = collections[i].spliterator();
        }
        Consumer<? super T> action = t -> accumulate(i, t);
        if (!spliterators[i].tryAdvance(action)) {
            // Carry into the previous position, then restart this one.
            if (i == 0 || collections[i - 1] == null || !advance(i - 1))
                return false;
            spliterators[i] = collections[i].spliterator();
            return spliterators[i].tryAdvance(action);
        }
        return true;
    }

    /**
     * Splits off a prefix of the remaining combinations by splitting the spliterator at {@code splitPos}.
     *
     * <p>
     * If that spliterator cannot be split, its remaining elements are buffered.
     * <ul>
     * <li>If several remain, they are re-split through an array spliterator.</li>
     * <li>If exactly one remains and traversal of the later positions has not started, that element is
     * fixed and splitting moves on to the next position.</li>
     * <li>If exactly one remains but traversal is in progress, the split is declined. Fixing the
     * element would overwrite the currently accumulated one and drop its unfinished combinations.</li>
     * </ul>
     *
     * @return a spliterator covering a prefix of the remaining combinations, or {@code null} if this
     *         spliterator cannot be split
     */
    @Override
    public Spliterator<A> trySplit() {
        if (spliterators[splitPos] == null)
            spliterators[splitPos] = collections[splitPos].spliterator();
        Spliterator<T> res = spliterators[splitPos].trySplit();
        if (res == null) {
            if (splitPos == spliterators.length - 1)
                return null;
            // Buffer the rest of this position so it can be fixed or re-split.
            @SuppressWarnings("unchecked")
            T[] arr = (T[]) StreamSupport.stream(spliterators[splitPos], false).toArray();
            if (arr.length == 0)
                return null;
            if (arr.length == 1) {
                if (spliterators[splitPos + 1] != null) {
                    // Traversal is in the middle of the combinations for the element currently
                    // accumulated at splitPos. Restore the buffered element and decline to split.
                    spliterators[splitPos] = Spliterators.spliterator(arr, Spliterator.ORDERED);
                    return null;
                }
                accumulate(splitPos, arr[0]);
                splitPos++;
                return trySplit();
            }
            spliterators[splitPos] = Spliterators.spliterator(arr, Spliterator.ORDERED);
            return trySplit();
        }
        long prefixEst = Long.MAX_VALUE;
        // This spliterator keeps the remainder at splitPos, combined with every later position.
        long newEst = spliterators[splitPos].getExactSizeIfKnown();
        if (newEst == -1) {
            newEst = Long.MAX_VALUE;
        } else {
            try {
                for (int i = splitPos + 1; i < collections.length; i++) {
                    long size = collections[i].size();
                    newEst = StrictMath.multiplyExact(newEst, size);
                }
                if (est != Long.MAX_VALUE)
                    prefixEst = est - newEst;
            } catch (ArithmeticException e) {
                newEst = Long.MAX_VALUE;
            }
        }
        // The prefix takes the split-off part at splitPos plus any in-progress later positions.
        Spliterator<T>[] prefixSpliterators = spliterators.clone();
        Collection<T>[] prefixCollections = collections.clone();
        prefixSpliterators[splitPos] = res;
        this.est = newEst;
        // This (suffix) spliterator restarts all later positions for its next element at splitPos.
        Arrays.fill(spliterators, splitPos + 1, spliterators.length, null);
        return doSplit(prefixEst, prefixSpliterators, prefixCollections);
    }

    /**
     * @return the estimated number of remaining combinations, or {@code Long.MAX_VALUE} if unknown
     */
    @Override
    public long estimateSize() {
        return est;
    }

    /**
     * @return {@code ORDERED}, plus {@code SIZED} while the size estimate is known (below
     *         {@code Long.MAX_VALUE})
     */
    @Override
    public int characteristics() {
        int sized = est < Long.MAX_VALUE ? SIZED : 0;
        return ORDERED | sized;
    }
}