package com.macasaet.fernet.example.pb;

import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.MockitoAnnotations.initMocks;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Objects;
import java.util.Random;
import java.util.UUID;
import java.util.function.Function;
import java.util.function.Predicate;

import javax.servlet.http.HttpServletResponse;

import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;

import com.macasaet.fernet.Key;
import com.macasaet.fernet.Token;
import com.macasaet.fernet.TokenValidationException;
import com.macasaet.fernet.Validator;
import com.macasaet.fernet.example.pb.Example.Session;
import com.macasaet.fernet.example.pb.Example.Session.Builder;

/**
 * This class demonstrates storing binary content (a Protocol Buffers {@code Session} message)
 * in the payload of a Fernet token.
 *
 * <p>Copyright &copy; 2017 Carlos Macasaet.</p>
 *
 * @author Carlos Macasaet
 */
public class ProtocolBuffersExampleIT {

    /** Randomness passed to key generation and to {@code Token.generate} when issuing tokens. */
    final Random random = new SecureRandom();

    /** The single Fernet key used to both issue and validate session tokens in this example. */
    final Key key = Key.generateKey(random);

    /** UTF-8 charset. NOTE(review): unused by the methods visible here — remove if unreferenced. */
    final Charset charset = StandardCharsets.UTF_8;

    private Validator<Session> validator = new Validator<Session>() {
        public Function<byte[], Session> getTransformer() {
            return bytes -> {
                try {
                    return Session.parseFrom(bytes);
                } catch (final InvalidProtocolBufferException e) {
                    throw new TokenValidationException("Invalid payload: " + e.getMessage(), e);

        public Predicate<Session> getObjectValidator() {
            return session -> session.getRenewalCount() < 10;

        public TemporalAmount getTimeToLive() {
            return Duration.ofMinutes(5);

     * Start a new session.
     * @return a serialised Fernet token with a {@link Session} embedded in the payload
    public String createSession(@Context final HttpServletResponse servletResponse) {
        final String sessionId = UUID.randomUUID().toString();
        final Builder builder = Session.newBuilder();
        servletResponse.addHeader("Location", "/api/sessions/" + sessionId);
        final Session session =;
        final Token token = Token.generate(random, key, session.toByteArray());
        return token.serialise();

     * Renew a session 
     * @param sessionId the existing session ID
     * @param tokenString a current valid Fernet token
     * @return a new Fernet token with the updated session state embedded
    public String renew(@PathParam("sessionId") final String sessionId, final String tokenString,
            @Context final HttpServletResponse servletResponse) {
        final Token inputToken = Token.fromString(tokenString);
        final Session session = inputToken.validateAndDecrypt(key, validator);
        if (!Objects.equals(sessionId, session.getSessionId())) {
            throw new BadRequestException("SessionID mismatch.");

        final Instant lastRenewed = Instant.ofEpochSecond(session.getLastRenewalTime());
        if (session.hasLastRenewalTime() && lastRenewed.isAfter( {
            // prevent denial-of-service
            // if token was renewed less than a minute ago, tell the client to back off
            servletResponse.addHeader("Retry-After", "60");
            // Too Many Requests:
            throw new WebApplicationException("Try again in a minute", 429);

        // The token and session are valid, now update the session
        final Builder builder = Session.newBuilder(session);
        builder.setRenewalCount(session.getRenewalCount() + 1);
        final Session updatedSession =;
        // store the updated session in a new Fernet token
        final Token retval = Token.generate(random, key, updatedSession.toByteArray());
        return retval.serialise();

    ArgumentCaptor<String> locationHeaderCaptor;

    public void setUp() {

    public final void testRenewal() {
        // given
        final HttpServletResponse initialResponse = mock(HttpServletResponse.class);
        final String initialToken = createSession(initialResponse);
        verify(initialResponse).addHeader(eq("Location"), locationHeaderCaptor.capture());
        final String location = locationHeaderCaptor.getValue();
        final String sessionId = location.substring(location.lastIndexOf('/') + 1);

        // when
        final HttpServletResponse renewalResponse = mock(HttpServletResponse.class);
        final String subsequentToken = renew(sessionId, initialToken, renewalResponse);

        // then
        final Session result = Token.fromString(subsequentToken).validateAndDecrypt(key, validator);
        assertEquals(1, result.getRenewalCount());
        assertEquals(sessionId, result.getSessionId());