package com.github.toastshaman.dropwizard.auth.jwt;

import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheBuilderSpec;
import com.google.common.cache.CacheStats;
import io.dropwizard.auth.AuthenticationException;
import io.dropwizard.auth.Authenticator;
import org.jose4j.jwt.consumer.JwtContext;

import java.security.Principal;
import java.util.AbstractMap.SimpleEntry;
import java.util.Optional;
import java.util.function.Predicate;

import static com.codahale.metrics.MetricRegistry.name;

/**
 * An {@link Authenticator} decorator that remembers successful authentication results in a
 * Guava {@link Cache}, keyed by the value returned from {@link JwtContext#getJwt()}.
 *
 * <p>Only results that actually contain a principal are stored; an empty result from the
 * underlying authenticator is returned to the caller but never cached. Each cached value keeps
 * the originating {@link JwtContext} next to the principal so that entries can later be
 * selected for invalidation by inspecting the context.
 *
 * @param <P> the type of principal produced by the underlying authenticator
 */
public class CachingJwtAuthenticator<P extends Principal> implements Authenticator<JwtContext, P> {

    private final Authenticator<JwtContext, P> authenticator;
    private final Cache<String, SimpleEntry<JwtContext, Optional<P>>> cache;
    private final Meter cacheMisses;
    private final Timer gets;

    /**
     * Builds a caching authenticator whose cache is configured from a specification.
     *
     * @param metricRegistry the registry in which the cache metrics are registered
     * @param authenticator  the authenticator consulted on a cache miss
     * @param cacheSpec      the {@link CacheBuilderSpec} describing the cache
     */
    public CachingJwtAuthenticator(final MetricRegistry metricRegistry,
                                   final Authenticator<JwtContext, P> authenticator,
                                   final CacheBuilderSpec cacheSpec) {
        this(metricRegistry, authenticator, CacheBuilder.from(cacheSpec));
    }

    /**
     * Builds a caching authenticator whose cache is produced by the given builder.
     *
     * <p>A {@code cache-misses} meter and a {@code gets} timer are registered under names
     * derived from the runtime class of the underlying authenticator, and statistics recording
     * is switched on for the cache.
     *
     * @param metricRegistry the registry in which the cache metrics are registered
     * @param authenticator  the authenticator consulted on a cache miss
     * @param builder        the {@link CacheBuilder} used to create the cache
     */
    public CachingJwtAuthenticator(final MetricRegistry metricRegistry,
                                   final Authenticator<JwtContext, P> authenticator,
                                   final CacheBuilder<Object, Object> builder) {
        this.authenticator = authenticator;

        // Metric names are scoped to the wrapped authenticator's class, not to this decorator.
        final Class<?> delegateType = authenticator.getClass();
        this.cacheMisses = metricRegistry.meter(name(delegateType, "cache-misses"));
        this.gets = metricRegistry.timer(name(delegateType, "gets"));

        this.cache = builder.recordStats().build();
    }

    /**
     * Returns the cached principal for the context's JWT, or delegates to the underlying
     * authenticator when nothing is cached.
     *
     * <p>Every call is timed, and every delegation is counted as a cache miss. A result is
     * stored only when the underlying authenticator returned a principal.
     *
     * @param context the JWT context to authenticate
     * @return the cached or freshly computed principal, possibly empty
     * @throws AuthenticationException if the underlying authenticator fails
     */
    @Override
    public Optional<P> authenticate(JwtContext context) throws AuthenticationException {
        // Closing the timer context stops it, on both the normal and the exceptional path.
        try (Timer.Context ignored = gets.time()) {
            final String token = context.getJwt();
            final SimpleEntry<JwtContext, Optional<P>> cached = cache.getIfPresent(token);

            if (cached == null) {
                cacheMisses.mark();
                final Optional<P> result = authenticator.authenticate(context);
                if (result.isPresent()) {
                    cache.put(token, new SimpleEntry<>(context, result));
                }
                return result;
            }

            return cached.getValue();
        }
    }

    /**
     * Removes the cache entry stored for the JWT of the given credentials, if there is one.
     *
     * @param credentials the credentials whose cached principal should be dropped
     */
    public void invalidate(JwtContext credentials) {
        final String token = credentials.getJwt();
        cache.invalidate(token);
    }

    /**
     * Removes the cache entries stored for each of the given credentials.
     *
     * @param credentials the credentials whose cached principals should be dropped
     */
    public void invalidateAll(Iterable<JwtContext> credentials) {
        for (final JwtContext context : credentials) {
            cache.invalidate(context.getJwt());
        }
    }

    /**
     * Removes every cache entry whose stored {@link JwtContext} is accepted by the predicate.
     *
     * @param predicate decides, per cached context, whether its entry is dropped
     */
    public void invalidateAll(Predicate<? super JwtContext> predicate) {
        cache.asMap().values().stream()
            .map(SimpleEntry::getKey)       // the JwtContext stored alongside the principal
            .filter(predicate)
            .map(JwtContext::getJwt)
            .forEach(cache::invalidate);
    }

    /**
     * Empties the cache.
     */
    public void invalidateAll() {
        cache.invalidateAll();
    }

    /**
     * Reports how many entries the cache currently holds, as returned by {@link Cache#size()}.
     *
     * @return the current entry count of the cache
     */
    public long size() {
        return cache.size();
    }

    /**
     * Exposes the statistics recorded by the underlying cache.
     *
     * @return the {@link CacheStats} of the underlying cache
     */
    public CacheStats stats() {
        return cache.stats();
    }
}