package io.github.bucket4j.util;

import io.github.bucket4j.Bucket;

import java.util.concurrent.CountDownLatch;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.junit.Assert.assertTrue;

/**
 * Test helper that runs several {@link ConsumerThread}s concurrently against a single shared
 * {@link Bucket} and asserts that the aggregate observed consumption rate does not exceed
 * a permitted rate.
 */
public class ConsumptionScenario {

    private final CountDownLatch startLatch;
    private final CountDownLatch endLatch;
    private final ConsumerThread[] consumers;
    // Taken from System.nanoTime(): elapsed-time measurement needs a monotonic clock, whereas
    // System.currentTimeMillis() follows the wall clock and can jump when it is adjusted.
    private final long initializationTimeNanos;
    private final double permittedRatePerSecond;

    /**
     * Prepares (but does not start) {@code threadCount} consumer threads sharing one bucket.
     *
     * @param threadCount            number of consumer threads; also the count of both latches
     * @param workTimeNanos          passed to every {@link ConsumerThread} (presumably how long each
     *                               thread keeps consuming — TODO confirm against ConsumerThread)
     * @param bucketSupplier         invoked exactly once; the resulting bucket is shared by all threads
     * @param action                 passed to every {@link ConsumerThread} (presumably performs one
     *                               consumption attempt and returns the tokens consumed — verify)
     * @param permittedRatePerSecond upper bound checked by {@link #executeAndValidateRate()}
     */
    public ConsumptionScenario(int threadCount, long workTimeNanos, Supplier<Bucket> bucketSupplier, Function<Bucket, Long> action, double permittedRatePerSecond) {
        this.startLatch = new CountDownLatch(threadCount);
        this.endLatch = new CountDownLatch(threadCount);
        this.consumers = new ConsumerThread[threadCount];
        // The clock starts before the bucket is created, so the measured window spans the
        // bucket's entire lifetime.
        this.initializationTimeNanos = System.nanoTime();
        this.permittedRatePerSecond = permittedRatePerSecond;
        Bucket bucket = bucketSupplier.get();
        for (int i = 0; i < threadCount; i++) {
            this.consumers[i] = new ConsumerThread(startLatch, endLatch, bucket, workTimeNanos, action);
        }
    }

    /**
     * Starts all consumer threads, waits on the end latch, then checks the observed rate.
     *
     * <p>The rate is the total of {@code getConsumed()} over all consumers, divided by the time
     * elapsed since construction.
     *
     * @throws Exception the exception recorded by the first consumer (in array order) that failed
     * @throws AssertionError (via JUnit {@code assertTrue}) if the observed rate exceeds the permitted rate
     */
    public void executeAndValidateRate() throws Exception {
        for (ConsumerThread consumer : consumers) {
            consumer.start();
        }
        // Assumes each ConsumerThread counts down endLatch when it finishes — TODO confirm.
        endLatch.await();
        // Clamp to 1ns so a zero-length window (e.g. no threads) cannot produce NaN/Infinity
        // and fail the assertion spuriously.
        long durationNanos = Math.max(1L, System.nanoTime() - initializationTimeNanos);
        long durationMillis = durationNanos / 1_000_000L;

        long consumed = 0;
        for (ConsumerThread consumer : consumers) {
            if (consumer.getException() != null) {
                throw consumer.getException();
            }
            consumed += consumer.getConsumed();
        }

        // Use nanosecond precision so truncation to whole milliseconds cannot inflate the rate.
        double actualRatePerSecond = (double) consumed * 1_000_000_000.0d / durationNanos;
        System.out.println("Consumed " + consumed + " tokens in the "
                + durationMillis + " millis, actualRatePerSecond=" + Formatter.format(actualRatePerSecond)
                + ", permitted rate=" + Formatter.format(permittedRatePerSecond));

        String msg = "Actual rate " + Formatter.format(actualRatePerSecond) + " is greater than permitted rate " + Formatter.format(permittedRatePerSecond);
        assertTrue(msg, actualRatePerSecond <= permittedRatePerSecond);
    }

}