package com.amazonaws.encryptionsdk.caching;

import static com.amazonaws.encryptionsdk.caching.CacheTestFixtures.createMaterialsResult;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;

import org.junit.Test;

import com.amazonaws.encryptionsdk.DataKey;
import com.amazonaws.encryptionsdk.caching.CryptoMaterialsCache.UsageStats;
import com.amazonaws.encryptionsdk.model.DecryptionMaterials;
import com.amazonaws.encryptionsdk.model.EncryptionMaterials;

public class LocalCryptoMaterialsCacheThreadStormTest {

    /*
     * This test exercises LocalCryptoMaterialsCache under contention at the cache level.
     *
     * Worker threads concurrently:
     *
     * 1. Put encrypt and decrypt entries, including repeated puts under the same cache ID.
     * 2. Get encrypt and decrypt entries and verify that only keys recorded as "possible" are returned.
     *
     * Periodically, all workers are paused at a barrier and we verify that the system state is sane. This is
     * done by reading the private cacheMap and expirationQueue members of LocalCryptoMaterialsCache via
     * reflection and verifying that both hold the same set of entries.
     *
     * NOTE(review): none of the worker threads below issues explicit invalidations or changes the cache
     * capacity; add dedicated workers if those paths need to be covered.
     */

    // Private member accessors (reflection-based; initialized in the static block below)
    private static final Function<LocalCryptoMaterialsCache, HashMap<?, ?>> get_cacheMap;
    private static final Function<LocalCryptoMaterialsCache, TreeSet<?>> get_expirationQueue;

    /**
     * Builds a reflective accessor for a declared (possibly private) field.
     *
     * @param klass     class declaring the field
     * @param fieldName name of the field to read
     * @param <T>       type of the object the field is read from
     * @param <R>       type the field value is cast to; the cast is unchecked, so the caller must pick a type
     *                  compatible with the real field type
     * @return a function that reads the field from a given instance, wrapping any failure to read it in a
     *         RuntimeException
     * @throws Error if the field cannot be looked up or made accessible
     */
    @SuppressWarnings("unchecked") // the field's static type is not known here; callers choose R
    private static <T, R> Function<T, R> getGetter(Class<T> klass, String fieldName) {
        try {
            Field f = klass.getDeclaredField(fieldName);
            f.setAccessible(true);

            return obj -> {
                try {
                    return (R)f.get(obj);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            };
        } catch (Exception e) {
            throw new Error(e);
        }
    }

    static {
        get_cacheMap = getGetter(LocalCryptoMaterialsCache.class, "cacheMap");
        get_expirationQueue = getGetter(LocalCryptoMaterialsCache.class, "expirationQueue");
    }

    /**
     * Asserts that the values of the cache's private cacheMap and the contents of its private expirationQueue
     * are the same set of objects. Both are snapshotted while holding the cache's monitor.
     *
     * @param cache the cache to inspect
     */
    public static void assertConsistent(LocalCryptoMaterialsCache cache) {
        synchronized (cache) {
            HashSet<Object> expirationQueue = new HashSet<>(get_expirationQueue.apply(cache));
            HashSet<Object> cacheMap = new HashSet<>(get_cacheMap.apply(cache).values());

            assertEquals("Cache group entries are inconsistent with expiration queue",
                         cacheMap, expirationQueue);
        }
    }

    // Cache under test; created in test().
    LocalCryptoMaterialsCache cache;

    // When barrierRequest = true, each worker thread joins the barrier twice on its next loop iteration: once to
    // signal that it has stopped touching the cache, and once to wait for the consistency check to finish.
    CyclicBarrier barrier;
    volatile boolean barrierRequest = false;
    // Counted down once by test() to ask every thread to leave its loop.
    CountDownLatch stopRequest = new CountDownLatch(1);

    // Decrypt results that _might_ be returned, by cache ID. The add and check threads both synchronize on the
    // inner map while touching the cache for that ID, so a key is recorded here before any checker can observe
    // it; the check thread fails the test if the cache returns a key that is not recorded here.
    // The values of the inner map are arbitrary but non-null (we use this effectively like a set).
    ConcurrentHashMap<ByteBuffer, ConcurrentHashMap<CacheTestFixtures.SentinelKey, Object>> possibleDecrypts = new ConcurrentHashMap<>();

    // Same structure as possibleDecrypts, but for encrypt entries.
    // The values of the inner map are arbitrary but non-null (we use this effectively like a set)
    ConcurrentHashMap<ByteBuffer, ConcurrentHashMap<CacheTestFixtures.SentinelKey, Object>> possibleEncrypts = new ConcurrentHashMap<>();

    // Counters for debugging the test itself. If null, this debug infrastructure is disabled.
    private ConcurrentHashMap<String, AtomicLong> counters = null; //new ConcurrentHashMap<>();

    /**
     * Increments the named debug counter; does nothing when the counters map is null.
     *
     * @param s counter name
     */
    void inc(String s) {
        if (counters != null) {
            counters.computeIfAbsent(s, ignored -> new AtomicLong(0)).incrementAndGet();
        }
    }

    // Template materials; each worker derives per-entry materials from these via toBuilder().
    private static final EncryptionMaterials BASE_ENCRYPT = CacheTestFixtures.createMaterialsResult();
    private static final DecryptionMaterials BASE_DECRYPT
            = CacheTestFixtures.createDecryptResult(CacheTestFixtures.createDecryptRequest(0));

    /**
     * Called by every worker at the top of each loop iteration. If a barrier has been requested, waits on the
     * barrier twice: the first wait lines up with the checker's first wait (workers are now quiescent), the
     * second with the checker's second wait (check finished).
     *
     * @throws RuntimeException wrapping any exception thrown while waiting on the barrier
     */
    private void maybeBarrier() {
        if (barrierRequest) {
            try {
                barrier.await();
                barrier.await();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    // This thread continually adds items to the decrypt cache, logging ones it added.
    // The expectedDecryptMap has multiple items because we don't know if the cache expired the prior one; the
    // decrypt check thread will check and forget/forbid the expected items that were not found.
    public void decryptAddThread() {
        int nItemsBeforeRelax = 200_000;
        int nItems = 0;

        try {
            while (stopRequest.getCount() > 0) {
                maybeBarrier();

                // Random 3-byte cache ID whose first byte is forced to zero.
                byte[] ref = new byte[3];
                ThreadLocalRandom.current().nextBytes(ref);
                ref[0] = 0;

                // A fresh sentinel key makes every cached result distinguishable by identity.
                CacheTestFixtures.SentinelKey key = new CacheTestFixtures.SentinelKey();
                DecryptionMaterials result = BASE_DECRYPT.toBuilder().setDataKey(
                        new DataKey(key, new byte[0], new byte[0], BASE_DECRYPT.getDataKey().getMasterKey())
                ).build();

                ConcurrentHashMap<CacheTestFixtures.SentinelKey, Object> expectedDecryptMap
                        = possibleDecrypts.computeIfAbsent(ByteBuffer.wrap(ref),
                                                           ignored -> new ConcurrentHashMap<>());

                // Put and bookkeeping happen under the per-ID lock so a checker for the same ID never sees the
                // cache entry without the matching record.
                synchronized (expectedDecryptMap) {
                    cache.putEntryForDecrypt(ref, result, () -> Long.MAX_VALUE);
                    expectedDecryptMap.put(key, this);
                }

                inc("decrypt put");

                // Briefly pause after every nItemsBeforeRelax puts.
                if (++nItems >= nItemsBeforeRelax) {
                    Thread.sleep(5);
                    nItems = 0;
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // The decrypt check thread verifies that the decrypt results are sane - specifically, if we don't see an item
    // that is known to have once been added to the cache, we should not see it reappear later.
    public void decryptCheckThread() {
        try {
            while (stopRequest.getCount() > 0) {
                maybeBarrier();

                // Same ID shape as decryptAddThread: 3 bytes, first byte zero.
                byte[] ref = new byte[3];
                ThreadLocalRandom.current().nextBytes(ref);
                ref[0] = 0;

                ConcurrentHashMap<CacheTestFixtures.SentinelKey, Object> expectedDecryptMap
                        = possibleDecrypts.computeIfAbsent(ByteBuffer.wrap(ref),
                                                           ignored -> new ConcurrentHashMap<>());

                synchronized (expectedDecryptMap) {
                    CryptoMaterialsCache.DecryptCacheEntry result = cache.getEntryForDecrypt(ref);

                    CacheTestFixtures.SentinelKey cachedKey = null;
                    if (result != null) {
                        inc("decrypt: hit");
                        cachedKey = (CacheTestFixtures.SentinelKey) result.getResult().getDataKey().getKey();
                        if (expectedDecryptMap.containsKey(cachedKey)) {
                            inc("decrypt: found key in expected");
                        } else {
                            fail("decrypt: unexpected key");
                        }
                    } else {
                        inc("decrypt: miss");
                    }

                    // Forget every recorded key other than the one just returned (identity comparison); if any
                    // of them is returned later, the containsKey check above fails the test.
                    // Removing while iterating is safe here because this is a ConcurrentHashMap.
                    for (CacheTestFixtures.SentinelKey expectedKey : expectedDecryptMap.keySet()) {
                        if (cachedKey != expectedKey) {
                            inc("decrypt: prune");
                            expectedDecryptMap.remove(expectedKey);
                        }
                    }
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // Continually adds encryption cache entries.
    public void encryptAddThread() {
        int nItemsBeforeRelax = 200_000;
        int nItems = 0;

        try {
            while (stopRequest.getCount() > 0) {
                maybeBarrier();

                // Random 2-byte cache ID.
                byte[] ref = new byte[2];
                ThreadLocalRandom.current().nextBytes(ref);

                EncryptionMaterials result = BASE_ENCRYPT.toBuilder().setCleartextDataKey(new CacheTestFixtures.SentinelKey()).build();
                ConcurrentHashMap<CacheTestFixtures.SentinelKey, Object> keys
                        = possibleEncrypts.computeIfAbsent(ByteBuffer.wrap(ref),
                                                           ignored -> new ConcurrentHashMap<>());
                // Put and bookkeeping happen under the per-ID lock, mirroring decryptAddThread.
                synchronized (keys) {
                    inc("encrypt: add");

                    cache.putEntryForEncrypt(ref, result, () -> Long.MAX_VALUE, UsageStats.ZERO);
                    keys.put((CacheTestFixtures.SentinelKey) result.getCleartextDataKey(), this);
                }

                // Briefly pause after every nItemsBeforeRelax puts.
                if (++nItems >= nItemsBeforeRelax) {
                    Thread.sleep(5);
                    nItems = 0;
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // Verifies that there is no resurrection, as above.
    public void encryptCheckThread() {
        try {
            while (stopRequest.getCount() > 0) {
                maybeBarrier();

                // Same ID shape as encryptAddThread: 2 random bytes.
                byte[] ref = new byte[2];
                ThreadLocalRandom.current().nextBytes(ref);

                ConcurrentHashMap<CacheTestFixtures.SentinelKey, Object> allowedKeys
                        = possibleEncrypts.computeIfAbsent(ByteBuffer.wrap(ref),
                                                           ignored -> new ConcurrentHashMap<>());

                synchronized (allowedKeys) {
                    // Holds zero or one key: whatever the cache returned for this ID.
                    HashSet<CacheTestFixtures.SentinelKey> foundKeys = new HashSet<>();
                    CryptoMaterialsCache.EncryptCacheEntry ece = cache.getEntryForEncrypt(ref, UsageStats.ZERO);

                    if (ece != null) {
                        foundKeys.add((CacheTestFixtures.SentinelKey)ece.getResult().getCleartextDataKey());
                    }

                    if (foundKeys.isEmpty()) {
                        inc("encrypt check: empty foundRefs");
                    } else {
                        inc("encrypt check: non-empty foundRefs");
                    }

                    foundKeys.forEach(foundKey -> {
                        if (!allowedKeys.containsKey(foundKey)) {
                            fail("encrypt check: unexpected key; " + allowedKeys + " " + foundKeys);
                        }
                    });

                    // Forget every recorded key that was not returned this time; if it is returned later, the
                    // check above fails the test.
                    allowedKeys.keySet().forEach(allowedKey -> {
                        if (!foundKeys.contains(allowedKey)) {
                            inc("encrypt check: prune");
                            // safe since this is a concurrent map
                            allowedKeys.remove(allowedKey);
                        }
                    });
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // Performs a consistency check of the cache's entry map vs its expiration queue periodically. Due to the high
    // overhead of this test, we run it infrequently (every 5 seconds until a stop is requested).
    public void checkThread() {
        try {
            while (!stopRequest.await(5000, TimeUnit.MILLISECONDS)) {
                // Ask the workers to park, then wait until all of them have reached the barrier.
                barrierRequest = true;
                barrier.await();

                assertConsistent(cache);
                inc("consistency check passed");

                // Clear the request BEFORE releasing the workers. Every worker is currently committed to its
                // second await, so none can observe the flag in between; without this reset the workers would
                // re-enter the barrier on their very next iteration and sit idle until the next check.
                barrierRequest = false;
                barrier.await();
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs two of each worker thread plus one checker for ten seconds. The test fails if any thread terminates
     * (normally or exceptionally) before the ten seconds elapse; a timeout is the expected outcome.
     */
    @Test
    public void test() throws Exception {
        cache = new LocalCryptoMaterialsCache(100_000);

        ArrayList<CompletableFuture<?>> futures = new ArrayList<>();
        ExecutorService es = Executors.newCachedThreadPool();

        ArrayList<Supplier<CompletableFuture<?>>> starters = new ArrayList<>();

        for (int i = 0; i < 2; i++) {
            starters.add(() -> CompletableFuture.runAsync(this::encryptAddThread, es));
            starters.add(() -> CompletableFuture.runAsync(this::encryptCheckThread, es));
            starters.add(() -> CompletableFuture.runAsync(this::decryptAddThread, es));
            starters.add(() -> CompletableFuture.runAsync(this::decryptCheckThread, es));
        }
        starters.add(() -> CompletableFuture.runAsync(this::checkThread, es));

        // Every started thread (workers and checker) is a barrier party, so size the barrier before starting any.
        barrier = new CyclicBarrier(starters.size());

        try {
            starters.forEach(s -> futures.add(s.get()));

            // Completes as soon as any thread finishes, which should not happen while the test is running.
            CompletableFuture<?> metaFuture = CompletableFuture.anyOf(futures.toArray(new CompletableFuture[0]));

            try {
                metaFuture.get(10, TimeUnit.SECONDS);
                fail("unexpected termination");
            } catch (TimeoutException e) {
                // ok
            }
        } finally {
            // Ask the loops to exit, then interrupt anything still blocked (barrier waits, sleeps).
            stopRequest.countDown();
            es.shutdownNow();

            es.awaitTermination(1, TimeUnit.SECONDS);

            if (counters != null) {
                new TreeMap<>(counters).forEach((k, v) -> System.out.println(String.format("%s: %d", k, v.get())));
            }
        }
    }
}