package es.moki.ratelimitj.test.limiter.request;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import es.moki.ratelimitj.core.limiter.request.RequestLimitRule;
import es.moki.ratelimitj.core.limiter.request.RequestRateLimiter;
import es.moki.ratelimitj.core.time.TimeSupplier;
import es.moki.ratelimitj.test.time.TimeBanditSupplier;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.time.Duration;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Contract tests for synchronous {@link RequestRateLimiter} implementations.
 *
 * <p>Concrete subclasses supply the limiter under test through
 * {@link #getRateLimiter(Set, TimeSupplier)}. Every test hands the limiter a
 * {@link TimeBanditSupplier} as its time source and advances it explicitly, and every test
 * uses its own key(s) so that counters are not shared between tests should an implementation
 * keep its state in a store that outlives a single test.
 */
@SuppressWarnings("PMD.AvoidUsingHardCodedIP")
public abstract class AbstractSyncRequestRateLimiterTest {

    private final TimeBanditSupplier timeBandit = new TimeBanditSupplier();

    /**
     * Creates the rate limiter implementation under test.
     *
     * @param rules        the limit rules the limiter must enforce
     * @param timeSupplier the time source the limiter is given (the tests pass a manually advanced one)
     * @return the limiter exercised by the tests in this class
     */
    protected abstract RequestRateLimiter getRateLimiter(Set<RequestLimitRule> rules, TimeSupplier timeSupplier);

    /**
     * With a 5 requests / 10 seconds rule, five requests spaced one second apart are not over
     * the limit and the sixth is.
     */
    @Test
    void shouldLimitSingleWindowSync() {
        Set<RequestLimitRule> rules = ImmutableSet.of(RequestLimitRule.of(Duration.ofSeconds(10), 5));
        RequestRateLimiter requestRateLimiter = getRateLimiter(rules, timeBandit);
        String key = "ip:127.0.1.1";

        IntStream.rangeClosed(1, 5).forEach(request -> {
            timeBandit.addUnixTimeMilliSeconds(1000L);
            assertThat(requestRateLimiter.overLimitWhenIncremented(key)).isFalse();
        });

        assertThat(requestRateLimiter.overLimitWhenIncremented(key)).isTrue();
    }

    /**
     * With a 5 requests / 10 seconds rule, {@code geLimitWhenIncremented} is false for the
     * first four requests and true for the fifth.
     */
    @Test
    void shouldGeLimitSingleWindowSync() {
        Set<RequestLimitRule> rules = ImmutableSet.of(RequestLimitRule.of(Duration.ofSeconds(10), 5));
        RequestRateLimiter requestRateLimiter = getRateLimiter(rules, timeBandit);
        String key = "ip:127.0.1.2";

        IntStream.rangeClosed(1, 4).forEach(request -> {
            timeBandit.addUnixTimeMilliSeconds(1000L);
            assertThat(requestRateLimiter.geLimitWhenIncremented(key)).isFalse();
        });

        assertThat(requestRateLimiter.geLimitWhenIncremented(key)).isTrue();
    }

    /**
     * With a 10 requests / 10 seconds rule, five increments of weight 2 are not over the limit
     * and a sixth weight-2 increment is.
     */
    @Test
    void shouldLimitWithWeightSingleWindowSync() {
        Set<RequestLimitRule> rules = ImmutableSet.of(RequestLimitRule.of(Duration.ofSeconds(10), 10));
        RequestRateLimiter requestRateLimiter = getRateLimiter(rules, timeBandit);
        // Deliberately distinct from the key used by shouldGeLimitSingleWindowSync so the two
        // tests cannot influence each other's counters.
        String key = "ip:127.0.1.3";
        int weight = 2;

        IntStream.rangeClosed(1, 5).forEach(request -> {
            timeBandit.addUnixTimeMilliSeconds(1000L);
            assertThat(requestRateLimiter.overLimitWhenIncremented(key, weight)).isFalse();
        });

        assertThat(requestRateLimiter.overLimitWhenIncremented(key, weight)).isTrue();
    }

    /**
     * Ten keys are counted independently under one 5 requests / 10 seconds rule: each key takes
     * five requests, is over the limit on its sixth, and is no longer over the limit after the
     * time source has been advanced by a further five seconds.
     */
    @Test
    void shouldLimitSingleWindowSyncWithMultipleKeys() {
        Set<RequestLimitRule> rules = ImmutableSet.of(RequestLimitRule.of(Duration.ofSeconds(10), 5));
        RequestRateLimiter requestRateLimiter = getRateLimiter(rules, timeBandit);
        String keyPrefix = "ip:127.0.0.";

        IntStream.rangeClosed(1, 5).forEach(request -> {
            timeBandit.addUnixTimeMilliSeconds(1000L);
            IntStream.rangeClosed(1, 10).forEach(
                    keySuffix -> assertThat(requestRateLimiter.overLimitWhenIncremented(keyPrefix + keySuffix)).isFalse());
        });

        IntStream.rangeClosed(1, 10).forEach(
                keySuffix -> assertThat(requestRateLimiter.overLimitWhenIncremented(keyPrefix + keySuffix)).isTrue());

        timeBandit.addUnixTimeMilliSeconds(5000L);
        IntStream.rangeClosed(1, 10).forEach(
                keySuffix -> assertThat(requestRateLimiter.overLimitWhenIncremented(keyPrefix + keySuffix)).isFalse());
    }

    /**
     * A rule restricted to one key (5 requests / 10 seconds) limits that key after five
     * requests, while a different key is only limited by the unrestricted rule
     * (10 requests / 10 seconds) after ten requests.
     */
    @Test
    void shouldLimitSingleWindowSyncWithKeySpecificRules() {
        String restrictedKey = "ip:127.9.0.0";
        String otherKey = "ip:127.9.1.0";

        RequestLimitRule keySpecificRule = RequestLimitRule.of(Duration.ofSeconds(10), 5).matchingKeys(restrictedKey);
        RequestLimitRule defaultRule = RequestLimitRule.of(Duration.ofSeconds(10), 10);

        RequestRateLimiter requestRateLimiter = getRateLimiter(ImmutableSet.of(keySpecificRule, defaultRule), timeBandit);

        IntStream.rangeClosed(1, 5).forEach(request -> {
            timeBandit.addUnixTimeMilliSeconds(1000L);
            assertThat(requestRateLimiter.overLimitWhenIncremented(restrictedKey)).isFalse();
        });
        assertThat(requestRateLimiter.overLimitWhenIncremented(restrictedKey)).isTrue();

        // The time source is intentionally not advanced for the second key.
        IntStream.rangeClosed(1, 10).forEach(
                request -> assertThat(requestRateLimiter.overLimitWhenIncremented(otherKey)).isFalse());
        assertThat(requestRateLimiter.overLimitWhenIncremented(otherKey)).isTrue();
    }

    /**
     * {@code resetLimit} returns true for a key that has been counted and false when called
     * again straight afterwards; following the reset the key is no longer over the limit.
     */
    @Test
    void shouldResetLimit() {
        Set<RequestLimitRule> rules = ImmutableSet.of(RequestLimitRule.of(Duration.ofSeconds(60), 1));
        RequestRateLimiter requestRateLimiter = getRateLimiter(rules, timeBandit);
        String key = "ip:127.1.0.1";

        assertThat(requestRateLimiter.overLimitWhenIncremented(key)).isFalse();
        assertThat(requestRateLimiter.overLimitWhenIncremented(key)).isTrue();

        assertThat(requestRateLimiter.resetLimit(key)).isTrue();
        assertThat(requestRateLimiter.resetLimit(key)).isFalse();

        assertThat(requestRateLimiter.overLimitWhenIncremented(key)).isFalse();
    }

    /**
     * Issues 50 x 250 requests, 14 ms apart, against a 250 requests / 5 seconds rule with a
     * one second precision. Whenever a request is over the limit, the time elapsed since the
     * last permitted request must be below 3, measured in the units of {@code timeBandit.get()}
     * (presumably seconds -- confirm against {@link TimeBanditSupplier}).
     */
    @Test
    void shouldRateLimitOverTime() {
        String key = "ip:127.3.9.3";
        RequestLimitRule rule = RequestLimitRule.of(Duration.ofSeconds(5), 250)
                .withPrecision(Duration.ofSeconds(1))
                .matchingKeys(key);
        RequestRateLimiter requestRateLimiter = getRateLimiter(ImmutableSet.of(rule), timeBandit);
        // Time of the most recent request that was NOT over the limit (0 until one is seen).
        AtomicLong timeOfLastOperation = new AtomicLong();

        IntStream.rangeClosed(1, 50).forEach(loop ->
                IntStream.rangeClosed(1, 250).forEach(request -> {
                    timeBandit.addUnixTimeMilliSeconds(14L);
                    boolean overLimit = requestRateLimiter.overLimitWhenIncremented(key);
                    if (overLimit) {
                        long timeSinceLastOperation = timeBandit.get() - timeOfLastOperation.get();
                        assertThat(timeSinceLastOperation).isLessThan(3);
                    } else {
                        timeOfLastOperation.set(timeBandit.get());
                    }
                }));
    }

    /**
     * Diagnostic only (disabled, and it makes no assertions): replays the load pattern of
     * {@link #shouldRateLimitOverTime()} and prints, for every observed {@code timeBandit.get()}
     * value in ascending order, how many requests were under and over the limit.
     */
    @Test @Disabled
    void shouldPreventThunderingHerdWithPrecision() {
        String key = "ip:127.9.9.9";
        RequestLimitRule rule = RequestLimitRule.of(Duration.ofSeconds(5), 250)
                .withPrecision(Duration.ofSeconds(1))
                .matchingKeys(key);
        RequestRateLimiter requestRateLimiter = getRateLimiter(ImmutableSet.of(rule), timeBandit);
        Map<Long, Integer> underPerSecond = new LinkedHashMap<>();
        Map<Long, Integer> overPerSecond = new HashMap<>();

        IntStream.rangeClosed(1, 50).forEach(loop ->
                IntStream.rangeClosed(1, 250).forEach(request -> {
                    timeBandit.addUnixTimeMilliSeconds(14L);
                    boolean overLimit = requestRateLimiter.overLimitWhenIncremented(key);
                    Map<Long, Integer> counts = overLimit ? overPerSecond : underPerSecond;
                    counts.merge(timeBandit.get(), 1, Integer::sum);
                }));

        // Sorted union so that every observed time value is reported exactly once, in order.
        Set<Long> allSeconds = Sets.newTreeSet(Sets.union(underPerSecond.keySet(), overPerSecond.keySet()));

        // getOrDefault: a time value with only one kind of outcome reports 0, not "null".
        allSeconds.forEach(second -> System.out.println("Time seconds : " + second
                + " under count : " + underPerSecond.getOrDefault(second, 0)
                + " over count : " + overPerSecond.getOrDefault(second, 0)));
    }

}