package rsc.parallel;

import java.util.Queue;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.function.Supplier;

import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import rsc.documentation.FusionMode;
import rsc.documentation.FusionSupport;
import rsc.flow.Fuseable;
import rsc.util.BackpressureHelper;
import rsc.util.ExceptionHelper;
import rsc.subscriber.SubscriptionHelper;

/**
 * Dispatches the values from upstream in a round robin fashion to subscribers which are
 * ready to consume elements. A value from upstream is sent to only one of the subscribers.
 * <p>
 * Each dispatched value is wrapped via {@code PrimaryOrderedItem.of(value, index)} where
 * {@code index} is a counter incremented once per dispatched value.
 *
 * @param <T> the value type
 */
@FusionSupport(input = { FusionMode.SYNC, FusionMode.ASYNC })
public final class ParallelOrderedSource<T> extends ParallelOrderedBase<T> {
    /** The upstream publisher whose values are distributed among the subscribers. */
    final Publisher<? extends T> source;

    /** The value reported by {@link #parallelism()}. */
    final int parallelism;

    /** The amount requested from upstream up front (non-SYNC modes). */
    final int prefetch;

    /** Supplies the buffer queue used when the upstream does not support fusion. */
    final Supplier<Queue<T>> queueSupplier;

    /**
     * Constructs a parallel ordered source.
     *
     * @param source the upstream publisher
     * @param parallelism the parallelism level reported by {@link #parallelism()}
     * @param prefetch the number of items to request from upstream initially
     * @param queueSupplier supplies the queue that buffers upstream items when the
     *                      upstream is not fused
     */
    public ParallelOrderedSource(Publisher<? extends T> source, int parallelism, int prefetch, Supplier<Queue<T>> queueSupplier) {
        this.source = source;
        this.parallelism = parallelism;
        this.prefetch = prefetch;
        this.queueSupplier = queueSupplier;
    }

    /**
     * Returns the parallelism level given at construction time.
     *
     * @return the parallelism level
     */
    @Override
    public int parallelism() {
        return parallelism;
    }

    /**
     * Subscribes a single dispatcher to the upstream which then feeds the given subscribers.
     * Nothing happens if {@code validate(subscribers)} (not defined in this file) rejects
     * the array.
     *
     * @param subscribers the subscribers that will receive the dispatched items
     */
    @Override
    public void subscribeOrdered(Subscriber<? super OrderedItem<T>>[] subscribers) {
        if (!validate(subscribers)) {
            return;
        }

        source.subscribe(new ParallelDispatcher<>(subscribers, prefetch, queueSupplier));
    }

    /**
     * The single upstream subscriber that buffers (or fuses with) the upstream and hands
     * items out to the downstream subscribers from a serialized drain loop.
     *
     * @param <T> the value type
     */
    static final class ParallelDispatcher<T> implements Subscriber<T> {

        final Subscriber<? super OrderedItem<T>>[] subscribers;

        /** Per-subscriber total requested amount; updated from the subscribers' request() calls. */
        final AtomicLongArray requests;

        /** Per-subscriber number of items emitted so far; only accessed inside the drain loops. */
        final long[] emissions;

        final int prefetch;

        /** Replenishment threshold: prefetch minus a quarter of prefetch. */
        final int limit;

        final Supplier<Queue<T>> queueSupplier;

        /** The upstream subscription. */
        Subscription s;

        /** Either the fused upstream queue or a queue obtained from {@link #queueSupplier}. */
        Queue<T> queue;

        /** Written before {@link #done} is set so a drain loop observing done also sees it. */
        Throwable error;

        volatile boolean done;

        /** Round-robin position saved between drain runs. */
        int index;

        volatile boolean cancelled;

        /** Work-in-progress counter serializing the drain loops. */
        volatile int wip;
        @SuppressWarnings("rawtypes")
        static final AtomicIntegerFieldUpdater<ParallelDispatcher> WIP =
                AtomicIntegerFieldUpdater.newUpdater(ParallelDispatcher.class, "wip");

        /**
         * Counts how many subscribers were setup to delay triggering the
         * drain of upstream until all of them have been setup.
         */
        volatile int subscriberCount;
        @SuppressWarnings("rawtypes")
        static final AtomicIntegerFieldUpdater<ParallelDispatcher> SUBSCRIBER_COUNT =
                AtomicIntegerFieldUpdater.newUpdater(ParallelDispatcher.class, "subscriberCount");

        /** Number of items consumed since the last replenishing request (ASYNC/NONE modes). */
        int produced;

        /** The fusion mode established in {@link #onSubscribe(Subscription)}. */
        int sourceMode;

        /** The next index to attach to a dispatched item; saved between drain runs. */
        long primaryIndex;

        /**
         * Constructs a dispatcher for the given subscribers.
         *
         * @param subscribers the downstream subscribers
         * @param prefetch the amount to request from upstream initially
         * @param queueSupplier supplies the buffer queue for the non-fused case
         */
        public ParallelDispatcher(Subscriber<? super OrderedItem<T>>[] subscribers, int prefetch, Supplier<Queue<T>> queueSupplier) {
            this.subscribers = subscribers;
            this.prefetch = prefetch;
            this.queueSupplier = queueSupplier;
            this.limit = prefetch - (prefetch >> 2);
            this.requests = new AtomicLongArray(subscribers.length);
            this.emissions = new long[subscribers.length];
        }

        /**
         * Stores the upstream subscription, negotiates fusion if the upstream offers it,
         * sets up the downstream subscribers and starts consuming.
         */
        @Override
        public void onSubscribe(Subscription s) {
            if (SubscriptionHelper.validate(this.s, s)) {
                this.s = s;

                if (s instanceof Fuseable.QueueSubscription) {
                    @SuppressWarnings("unchecked")
                    Fuseable.QueueSubscription<T> qs = (Fuseable.QueueSubscription<T>) s;

                    int m = qs.requestFusion(Fuseable.ANY);

                    if (m == Fuseable.SYNC) {
                        sourceMode = m;
                        queue = qs;
                        // In SYNC mode an empty queue is treated as completion by drainSync(),
                        // so the source counts as done from the start and is never requested.
                        done = true;
                        setupSubscribers();
                        drain();
                        return;
                    } else
                    if (m == Fuseable.ASYNC) {
                        sourceMode = m;
                        queue = qs;

                        setupSubscribers();

                        s.request(prefetch);

                        return;
                    }
                }

                queue = queueSupplier.get();

                setupSubscribers();

                s.request(prefetch);
            }
        }

        /**
         * Calls onSubscribe on each downstream subscriber with a Subscription that
         * accumulates requests per subscriber and routes cancel() to the dispatcher.
         * Stops early if the dispatcher gets cancelled in the meantime.
         */
        void setupSubscribers() {
            int m = subscribers.length;

            for (int i = 0; i < m; i++) {
                if (cancelled) {
                    return;
                }
                int j = i;

                SUBSCRIBER_COUNT.lazySet(this, i + 1);

                subscribers[i].onSubscribe(new Subscription() {
                    @Override
                    public void request(long n) {
                        if (SubscriptionHelper.validate(n)) {
                            AtomicLongArray ra = requests;
                            for (;;) {
                                long r = ra.get(j);
                                if (r == Long.MAX_VALUE) {
                                    return;
                                }
                                long u = BackpressureHelper.addCap(r, n);
                                if (ra.compareAndSet(j, r, u)) {
                                    break;
                                }
                            }
                            // Only drain once every subscriber has been (or is being) set up.
                            if (subscriberCount == m) {
                                drain();
                            }
                        }
                    }

                    @Override
                    public void cancel() {
                        ParallelDispatcher.this.cancel();
                    }
                });
            }
        }

        /**
         * Buffers the item (non-fused mode only) and triggers a drain. If the queue
         * rejects the item, the upstream is cancelled and an IllegalStateException is
         * signalled to the subscribers.
         */
        @Override
        public void onNext(T t) {
            if (sourceMode == Fuseable.NONE) {
                if (!queue.offer(t)) {
                    // Cancel only the upstream here. Calling this.cancel() would set the
                    // cancelled flag and bump wip, which makes the drain triggered by
                    // onError() a no-op and silently swallows the overflow error.
                    s.cancel();
                    onError(new IllegalStateException("Queue is full?"));
                    return;
                }
            }
            drain();
        }

        /** Records the error, marks the source as done and triggers a drain. */
        @Override
        public void onError(Throwable t) {
            error = t;
            done = true;
            drain();
        }

        /** Marks the source as done and triggers a drain. */
        @Override
        public void onComplete() {
            done = true;
            drain();
        }

        /**
         * Cancels the upstream and, if no drain loop is currently running, clears the
         * queue. A running drain loop observes the cancelled flag on its own.
         */
        void cancel() {
            if (!cancelled) {
                cancelled = true;
                this.s.cancel();

                if (WIP.getAndIncrement(this) == 0) {
                    queue.clear();
                }
            }
        }

        /** Signals the given error to every downstream subscriber. */
        void errorAll(Throwable ex) {
            for (Subscriber<? super OrderedItem<T>> subscriber : subscribers) {
                subscriber.onError(ex);
            }
        }

        /** Signals completion to every downstream subscriber. */
        void completeAll() {
            for (Subscriber<? super OrderedItem<T>> subscriber : subscribers) {
                subscriber.onComplete();
            }
        }

        /**
         * Drain loop for the ASYNC-fused and non-fused modes: hands queued items to the
         * subscribers with outstanding requests in round-robin order and replenishes the
         * upstream after every {@link #limit} consumed items.
         */
        void drainAsync() {
            int missed = 1;

            Queue<T> q = queue;
            Subscriber<? super OrderedItem<T>>[] a = this.subscribers;
            AtomicLongArray r = this.requests;
            long[] e = this.emissions;
            int n = e.length;
            int idx = index;
            int consumed = produced;
            long pi = primaryIndex;

            for (;;) {

                // Number of consecutive subscribers found without outstanding requests.
                int notReady = 0;

                for (;;) {
                    if (cancelled) {
                        q.clear();
                        return;
                    }

                    boolean d = done;
                    if (d) {
                        Throwable ex = error;
                        if (ex != null) {
                            // Errors are delivered eagerly; buffered items are dropped.
                            q.clear();
                            errorAll(ex);
                            return;
                        }
                    }

                    boolean empty = q.isEmpty();

                    if (d && empty) {
                        completeAll();
                        return;
                    }

                    if (empty) {
                        break;
                    }

                    long ridx = r.get(idx);
                    long eidx = e[idx];
                    if (ridx != eidx) {

                        T v;

                        try {
                            v = q.poll();
                        } catch (Throwable ex) {
                            // Same handling as drainSync(): a fused queue's poll() may throw.
                            ExceptionHelper.throwIfFatal(ex);
                            s.cancel();
                            errorAll(ex);
                            return;
                        }

                        a[idx].onNext(PrimaryOrderedItem.of(v, pi++));

                        e[idx] = eidx + 1;

                        int c = ++consumed;
                        if (c == limit) {
                            consumed = 0;
                            s.request(c);
                        }
                        notReady = 0;
                    } else {
                        notReady++;
                    }

                    idx++;
                    if (idx == n) {
                        idx = 0;
                    }

                    // A full round without any ready subscriber: nothing more to do for now.
                    if (notReady == n) {
                        break;
                    }
                }

                int w = wip;
                if (w == missed) {
                    // No additional drain() calls since the last check: save the loop
                    // state and try to leave the drain loop.
                    index = idx;
                    produced = consumed;
                    primaryIndex = pi;
                    missed = WIP.addAndGet(this, -missed);
                    if (missed == 0) {
                        break;
                    }
                } else {
                    missed = w;
                }
            }
        }

        /**
         * Drain loop for the SYNC-fused mode: the upstream queue is polled directly, an
         * empty queue means completion and nothing is requested from upstream.
         */
        void drainSync() {
            int missed = 1;

            Queue<T> q = queue;
            Subscriber<? super OrderedItem<T>>[] a = this.subscribers;
            AtomicLongArray r = this.requests;
            long[] e = this.emissions;
            int n = e.length;
            int idx = index;
            long pi = primaryIndex;

            for (;;) {

                // Number of consecutive subscribers found without outstanding requests.
                int notReady = 0;

                for (;;) {
                    if (cancelled) {
                        return;
                    }

                    boolean empty;

                    try {
                        empty = q.isEmpty();
                    } catch (Throwable ex) {
                        ExceptionHelper.throwIfFatal(ex);
                        s.cancel();
                        errorAll(ex);
                        return;
                    }

                    if (empty) {
                        completeAll();
                        return;
                    }

                    long ridx = r.get(idx);
                    long eidx = e[idx];
                    if (ridx != eidx) {

                        T v;

                        try {
                            v = q.poll();
                        } catch (Throwable ex) {
                            ExceptionHelper.throwIfFatal(ex);
                            s.cancel();
                            errorAll(ex);
                            return;
                        }

                        a[idx].onNext(PrimaryOrderedItem.of(v, pi++));

                        e[idx] = eidx + 1;

                        notReady = 0;
                    } else {
                        notReady++;
                    }

                    idx++;
                    if (idx == n) {
                        idx = 0;
                    }

                    // A full round without any ready subscriber: nothing more to do for now.
                    if (notReady == n) {
                        break;
                    }
                }

                int w = wip;
                if (w == missed) {
                    // No additional drain() calls since the last check: save the loop
                    // state and try to leave the drain loop.
                    index = idx;
                    primaryIndex = pi;
                    missed = WIP.addAndGet(this, -missed);
                    if (missed == 0) {
                        break;
                    }
                } else {
                    missed = w;
                }
            }
        }

        /**
         * Enters the drain loop matching the source mode unless another drain is already
         * in progress, in which case the running loop is notified via the wip counter.
         */
        void drain() {
            if (WIP.getAndIncrement(this) != 0) {
                return;
            }

            if (sourceMode == Fuseable.SYNC) {
                drainSync();
            } else {
                drainAsync();
            }
        }
    }
}