package com.merakianalytics.orianna.datapipeline.common.rates;

import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class FixedWindowRateLimiter extends AbstractRateLimiter {
    private class Drainer extends TimerTask {
        private boolean cancelled = false;

        @Override
        public void run() {
            synchronized(resetterLock) {
                if(!cancelled) {
                    permitter.drainPermits();
                }
            }
        }
    }

    private class Resetter extends TimerTask {
        private boolean cancelled = false;

        @Override
        public void run() {
            synchronized(resetterLock) {
                if(!cancelled) {
                    permitter.drainPermits();
                    synchronized(currentlyProcessingLock) {
                        permitter.release(permits - currentlyProcessing);
                    }
                    if(drainer != null) {
                        drainer.cancel();
                    }
                    resetter = null;
                    drainer = null;
                    timer.purge();
                }
            }
        }
    }

    private volatile int currentlyProcessing = 0;
    private final Object currentlyProcessingLock = new Object();
    private volatile Drainer drainer = null;
    private final long epoch;
    private final TimeUnit epochUnit;
    private volatile int permits;
    private final AtomicInteger permitsIssued = new AtomicInteger(0);
    private final Semaphore permitter;
    private volatile Resetter resetter = null;
    private final Object resetterLock = new Object();
    private final Timer timer = new Timer(true);

    public FixedWindowRateLimiter(final int permits, final long epoch, final TimeUnit epochUnit) {
        super(permits, epoch, epochUnit);
        this.permits = permits;
        this.epoch = epoch;
        this.epochUnit = epochUnit;
        permitter = new Semaphore(permits, true);
    }

    @Override
    public void acquire() throws InterruptedException {
        permitter.acquire();

        permitsIssued.incrementAndGet();

        synchronized(resetterLock) {
            if(drainer == null) {
                drainer = new Drainer();
                timer.schedule(drainer, epochUnit.toMillis(epoch));
            }
        }

        synchronized(currentlyProcessingLock) {
            currentlyProcessing += 1;
        }
    }

    @Override
    public boolean acquire(final long timeout, final TimeUnit unit) throws InterruptedException {
        if(timeout <= 0L) {
            acquire();
            return true;
        }

        if(!permitter.tryAcquire(timeout, unit)) {
            return false;
        }

        permitsIssued.incrementAndGet();

        synchronized(currentlyProcessingLock) {
            currentlyProcessing += 1;
        }

        synchronized(resetterLock) {
            if(drainer == null) {
                drainer = new Drainer();
                timer.schedule(drainer, epochUnit.toMillis(epoch));
            }
        }

        return true;
    }

    @Override
    public long getEpoch() {
        return epoch;
    }

    @Override
    public TimeUnit getEpochUnit() {
        return epochUnit;
    }

    @Override
    public int getPermits() {
        synchronized(resetterLock) {
            return permits;
        }
    }

    @Override
    public int permitsIssued() {
        return permitsIssued.get();
    }

    @Override
    public void release() {
        synchronized(currentlyProcessingLock) {
            currentlyProcessing -= 1;
        }

        synchronized(resetterLock) {
            if(resetter == null) {
                resetter = new Resetter();
                timer.schedule(resetter, epochUnit.toMillis(epoch));
            }
        }
    }

    @Override
    public ReservedPermit reserve() throws InterruptedException {
        permitter.acquire();
        permitsIssued.decrementAndGet();

        synchronized(currentlyProcessingLock) {
            currentlyProcessing += 1;
        }

        return new ReservedPermit() {
            @Override
            public void acquire() {
                permitsIssued.incrementAndGet();
            }

            @Override
            public void cancel() {
                synchronized(currentlyProcessingLock) {
                    currentlyProcessing -= 1;
                    permitter.release();
                }
            }
        };
    }

    @Override
    public ReservedPermit reserve(final long timeout, final TimeUnit unit) throws InterruptedException {
        if(!permitter.tryAcquire(timeout, unit)) {
            return null;
        }
        permitsIssued.decrementAndGet();

        synchronized(currentlyProcessingLock) {
            currentlyProcessing += 1;
        }

        return new ReservedPermit() {
            @Override
            public void acquire() {
                permitsIssued.incrementAndGet();
                synchronized(resetterLock) {
                    if(drainer == null) {
                        drainer = new Drainer();
                        timer.schedule(drainer, epochUnit.toMillis(epoch));
                    }
                }
            }

            @Override
            public void cancel() {
                synchronized(currentlyProcessingLock) {
                    currentlyProcessing -= 1;
                    permitter.release();
                }
            }
        };
    }

    @Override
    public void restrict(final long afterTime, final TimeUnit afterUnit, final long forTime, final TimeUnit forUnit) {
        synchronized(resetterLock) {
            if(drainer != null) {
                drainer.cancel();
                drainer.cancelled = true;
            }

            if(resetter != null) {
                resetter.cancel();
                resetter.cancelled = true;
            }

            drainer = new Drainer();
            timer.schedule(drainer, afterUnit.toMillis(afterTime));
            resetter = new Resetter();
            timer.schedule(resetter, afterUnit.toMillis(afterTime) + forUnit.toMillis(forTime));
        }
    }

    @Override
    public void restrictFor(final long time, final TimeUnit unit) {
        synchronized(resetterLock) {
            permitter.drainPermits();

            if(drainer != null) {
                drainer.cancel();
                drainer.cancelled = true;
            }

            if(resetter != null) {
                resetter.cancel();
                resetter.cancelled = true;
            }

            resetter = new Resetter();
            timer.schedule(resetter, unit.toMillis(time));
        }
    }

    @Override
    public void setPermits(final int permits) {
        synchronized(resetterLock) {
            final int difference = permits - this.permits;
            if(difference > 0) {
                permitter.release(difference);
            } else if(difference < 0) {
                if(!permitter.tryAcquire(difference)) {
                    permitter.drainPermits();
                }
            }

            this.permits = permits;
        }
    }
}