package com.jessecoyle;

import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient;
import com.amazonaws.services.dynamodbv2.model.*;
import com.amazonaws.services.kms.AWSKMS;
import com.amazonaws.services.kms.AWSKMSClient;
import com.amazonaws.services.kms.model.DecryptRequest;
import com.amazonaws.services.kms.model.DecryptResult;
import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
import com.amazonaws.services.kms.model.GenerateDataKeyResult;
import org.apache.commons.codec.binary.Hex;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.internal.verification.VerificationModeFactory;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

/**
 * Unit tests for {@link JCredStash} using mocked DynamoDB and KMS clients.
 *
 * <p>No AWS calls are made: {@code setUp} stubs KMS {@code generateDataKey} and
 * {@code decrypt} with canned results. Each test stubs the DynamoDB operation it
 * exercises and inspects the request {@code JCredStash} sent.
 */
public class JCredStashTest {

    private AmazonDynamoDB dynamoDBClient;
    private AWSKMS awskmsClient;

    /**
     * Creates fresh mocks before each test.
     *
     * <p>The KMS mock returns a fixed {@link GenerateDataKeyResult} and a
     * {@link DecryptResult} with key id {@code "alias/foo"}. Their byte buffers are
     * Mockito mocks, so these results are placeholders rather than real key material.
     */
    @Before
    public void setUp() {
        dynamoDBClient = Mockito.mock(AmazonDynamoDB.class);

        GenerateDataKeyResult generateDatakeyResult = new GenerateDataKeyResult();
        generateDatakeyResult.setCiphertextBlob(Mockito.mock(ByteBuffer.class));
        generateDatakeyResult.setPlaintext(Mockito.mock(ByteBuffer.class));

        DecryptResult decryptResult = new DecryptResult();
        decryptResult.setKeyId("alias/foo");
        decryptResult.setPlaintext(Mockito.mock(ByteBuffer.class));

        awskmsClient = Mockito.mock(AWSKMS.class);
        Mockito.when(awskmsClient.generateDataKey(Mockito.any(GenerateDataKeyRequest.class))).thenReturn(generateDatakeyResult);
        Mockito.when(awskmsClient.decrypt(Mockito.any(DecryptRequest.class))).thenReturn(decryptResult);
    }

    /**
     * Stubs {@code dynamoDBClient.putItem} to record the request it receives and
     * return an empty {@link PutItemResult}.
     *
     * @return a one-element holder whose slot 0 receives the most recent
     *         {@link PutItemRequest} once {@code putItem} is invoked
     */
    private PutItemRequest[] capturePutItemRequests() {
        final PutItemRequest[] captured = new PutItemRequest[1];
        Mockito.when(dynamoDBClient.putItem(Mockito.any(PutItemRequest.class))).thenAnswer(invocationOnMock -> {
            captured[0] = (PutItemRequest) invocationOnMock.getArguments()[0];
            return new PutItemResult();
        });
        return captured;
    }

    /** A {@code null} version passed to {@code putSecret} is expected to be stored as the padded version 1. */
    @Test
    public void testPutSecretDefaultVersion() {
        PutItemRequest[] putItemRequest = capturePutItemRequests();

        JCredStash credStash = new JCredStash(dynamoDBClient, awskmsClient);
        credStash.putSecret("table", "mysecret", "foo", "alias/foo", new HashMap<>(), null);

        Mockito.verify(dynamoDBClient, Mockito.times(1)).putItem(Mockito.any(PutItemRequest.class));
        // JUnit convention: expected value first, actual second.
        Assert.assertEquals(padVersion(1), putItemRequest[0].getItem().get("version").getS());
    }

    /** An explicit version string passed to {@code putSecret} is expected to be stored verbatim. */
    @Test
    public void testPutSecretNewVersion() {
        String version = "foover";
        PutItemRequest[] putItemRequest = capturePutItemRequests();

        JCredStash credStash = new JCredStash(dynamoDBClient, awskmsClient);
        credStash.putSecret("table", "mysecret", "foo", "alias/foo", new HashMap<>(), version);

        Mockito.verify(dynamoDBClient, Mockito.times(1)).putItem(Mockito.any(PutItemRequest.class));
        Assert.assertEquals(version, putItemRequest[0].getItem().get("version").getS());
    }

    /**
     * The five-argument {@code putSecret} overload is expected to store one more than
     * the value returned by {@code getHighestVersion} (stubbed here to padded version 1).
     */
    @Test
    public void testPutSecretAutoIncrementVersion() {
        PutItemRequest[] putItemRequest = capturePutItemRequests();

        JCredStash credStash = Mockito.spy(new JCredStash(dynamoDBClient, awskmsClient));
        Mockito.doReturn(padVersion(1)).when(credStash).getHighestVersion("table", "mysecret");
        credStash.putSecret("table", "mysecret", "foo", "alias/foo", new HashMap<>());

        Mockito.verify(dynamoDBClient, Mockito.times(1)).putItem(Mockito.any(PutItemRequest.class));
        Assert.assertEquals(padVersion(2), putItemRequest[0].getItem().get("version").getS());
    }

    /**
     * Builds a DynamoDB item map with the attribute names used in these tests.
     *
     * @param secretName        value for the {@code "name"} attribute
     * @param newVersion        value for the {@code "version"} attribute
     * @param encryptedKeyBytes bytes stored Base64-encoded under {@code "key"}
     * @param contents          bytes stored Base64-encoded under {@code "contents"}
     * @param hmac              bytes stored hex-encoded under {@code "hmac"}
     * @return a new mutable attribute map
     */
    protected Map<String, AttributeValue> mockItem(String secretName, String newVersion, byte[] encryptedKeyBytes, byte[] contents, byte[] hmac) {

        Map<String, AttributeValue> item = new HashMap<>();
        item.put("name", new AttributeValue(secretName));
        item.put("version", new AttributeValue(newVersion));
        // encodeToString avoids decoding the Base64 bytes through the platform charset.
        item.put("key", new AttributeValue(Base64.getEncoder().encodeToString(encryptedKeyBytes)));
        item.put("contents", new AttributeValue(Base64.getEncoder().encodeToString(contents)));
        item.put("hmac", new AttributeValue(new String(Hex.encodeHex(hmac))));
        return item;
    }

    /**
     * {@code getSecret} without a version is expected to issue exactly one DynamoDB query.
     * It should return the result of the (stubbed) decryption overload.
     */
    @Test
    public void testGetSecret() {
        Mockito.when(dynamoDBClient.query(Mockito.any(QueryRequest.class))).thenReturn(
                new QueryResult().withCount(1).withItems(Arrays.asList(
                        mockItem("mysecret", padVersion(1), new byte[]{}, new byte[]{}, new byte[]{})
                )));

        JCredStash credStash = Mockito.spy(new JCredStash(dynamoDBClient, awskmsClient));
        Mockito.doReturn("foo").when(credStash).getSecret(Mockito.any(JCredStash.StoredSecret.class), Mockito.any(Map.class));

        String secret = credStash.getSecret("table", "mysecret", new HashMap<>());

        Mockito.verify(dynamoDBClient, Mockito.times(1)).query(Mockito.any(QueryRequest.class));
        Assert.assertEquals("foo", secret);
    }

    /** {@code getSecret} with an explicit version is expected to issue one {@code getItem} keyed on that version. */
    @Test
    public void testGetSecretWithVersion() {
        final GetItemRequest[] getItemRequest = new GetItemRequest[1];
        Mockito.when(dynamoDBClient.getItem(Mockito.any(GetItemRequest.class))).thenAnswer(invocationOnMock -> {
            getItemRequest[0] = (GetItemRequest) invocationOnMock.getArguments()[0];
            return new GetItemResult();
        });

        JCredStash credStash = Mockito.spy(new JCredStash(dynamoDBClient, awskmsClient));
        Mockito.doReturn("foo").when(credStash).getSecret(Mockito.any(JCredStash.StoredSecret.class), Mockito.any(Map.class));

        credStash.getSecret("table", "mysecret", new HashMap<>(), padVersion(1));

        Mockito.verify(dynamoDBClient, Mockito.times(1)).getItem(Mockito.any(GetItemRequest.class));
        Assert.assertEquals(padVersion(1), getItemRequest[0].getKey().get("version").getS());
    }

    /**
     * Formats a version number as a 19-digit, zero-padded decimal string.
     *
     * @param v the version number
     * @return {@code v} left-padded with zeros to 19 characters
     */
    private static String padVersion(int v) {
        return String.format("%019d", v);
    }
}