package de.saly.kafka.crypto;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import java.io.File;
import java.io.FileOutputStream;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import javax.crypto.Cipher;

import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.serialization.ByteArrayDeserializer;
import org.apache.kafka.common.serialization.ByteArraySerializer;
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.apache.kafka.common.serialization.StringSerializer;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;

public class EnDecryptionTest {

    private final static String TOPIC = "cryptedTestTopic";
    private final File pubKey;
    private final File privKey;
    private final byte[] publicKey;
    private final byte[] privateKey;

    public EnDecryptionTest() throws Exception {
        pubKey = File.createTempFile("kafka", "crypto");
        pubKey.deleteOnExit();
        privKey = File.createTempFile("kafka", "crypto");
        privKey.deleteOnExit();

        KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA");
        keyGen.initialize(2048);
        KeyPair pair = keyGen.genKeyPair();
        publicKey = pair.getPublic().getEncoded();
        privateKey = pair.getPrivate().getEncoded();

        //System.out.println("private key format: "+pair.getPrivate().getFormat()); // PKCS#8
        //System.out.println("public key format: "+pair.getPublic().getFormat()); // X.509

        FileOutputStream fout = new FileOutputStream(pubKey);
        fout.write(publicKey);
        fout.close();

        fout = new FileOutputStream(privKey);
        fout.write(privateKey);
        fout.close();
    }

    @Test
    public void testBasicStandard() throws Exception {
        testBasic("", 128, -1);
    }

    @Test
    public void testAes256() throws Exception {
        Assume.assumeTrue(Cipher.getMaxAllowedKeyLength("AES") >= 256);
        testBasic("", 256, -1);
    }

    @Test
    public void testSHA1() throws Exception {
        testBasic("SHA1", 128, -1);
    }

    @Test
    public void testSHA1_192() throws Exception {
        Assume.assumeTrue(Cipher.getMaxAllowedKeyLength("AES") >= 192);
        testBasic("SHA1", 192, -1);
    }

    @Test
    public void testMD5_192() throws Exception {
        Assume.assumeTrue(Cipher.getMaxAllowedKeyLength("AES") >= 192);
        testBasic("MD5", 192, -1);
    }
    
    @Test
    public void testMSHA512_128() throws Exception {
        testBasic("SHA-512", 128, -1);
    }

    @Test(expected = KafkaException.class)
    public void testInvalidKeySize() throws Exception {
        testBasic("SHA1", 177, -1);
    }

    @Test(expected = KafkaException.class)
    public void testInvalidHashAlgo() throws Exception {
        testBasic("xxx", 128, -1);
    }

    @Test
    public void testBasicInterval1() throws Exception {
        testBasic("", 128, 1);
    }
    
    @Test
    public void testBasicInterval10() throws Exception {
        testBasic("", 128, 10);
    }
    
    @Test
    public void testMultithreadedStandard() throws Exception {
        testMultithreadedBasic(-1);
    }
    
    @Test
    public void testMultithreadedInterval1() throws Exception {
        testMultithreadedBasic(1);
    }
    
    @Test
    public void testMultithreadedInterval1000() throws Exception {
        testMultithreadedBasic(1000);
    }
    
    protected void testMultithreadedBasic(int msgInterval) throws Exception {
        final String str = "The quick brown fox jumps over the lazy dog";

        final Map<String, Object> config = new HashMap<String, Object>();
        config.put(SerdeCryptoBase.CRYPTO_RSA_PRIVATEKEY_FILEPATH, privKey.getAbsolutePath());
        config.put(SerdeCryptoBase.CRYPTO_RSA_PUBLICKEY_FILEPATH, pubKey.getAbsolutePath());
        config.put(EncryptingSerializer.CRYPTO_VALUE_SERIALIZER, StringSerializer.class.getName());
        config.put(DecryptingDeserializer.CRYPTO_VALUE_DESERIALIZER, StringDeserializer.class);
        config.put(EncryptingSerializer.CRYPTO_NEW_KEY_MSG_INTERVAL, String.valueOf(msgInterval));

        final EncryptingSerializer<String> serializer = new EncryptingSerializer<String>();
        serializer.configure(config, false);
        
        final int threadCount = 200;

        final ExecutorService es = Executors.newFixedThreadPool(threadCount);
        final List<Future<Exception>> futures = new ArrayList<Future<Exception>>();

        for (int i = 0; i < threadCount; i++) {
            Future<Exception> f = es.submit(new Callable<Exception>() {

                @Override
                public Exception call() throws Exception {
                    try {
                        final Deserializer<String> deserializer = new DecryptingDeserializer<String>();
                        deserializer.configure(config, false);
                        for(int i=0; i<1000; i++) {
                            final byte[] enc = serializer.serialize(TOPIC, str+i+Thread.currentThread().getName());
                            assertEquals(str+i+Thread.currentThread().getName(), deserializer.deserialize(TOPIC, enc));
                        }
                        return null;
                    } catch (Exception e) {
                        return e;
                    }

                }

            });
            futures.add(f);
        }

        for (Future<Exception> f : futures) {
            try {
                Exception e = f.get();
                if (e != null) {
                    throw e;
                }
            } catch (Exception e) {
                e.printStackTrace();
                throw e;
            }
        }
    }

    protected void testBasic(String hashMethod, int keylen, int msgInterval) throws Exception {

        final Map<String, Object> config = new HashMap<String, Object>();
        config.put(SerdeCryptoBase.CRYPTO_RSA_PRIVATEKEY_FILEPATH, privKey.getAbsolutePath());
        config.put(SerdeCryptoBase.CRYPTO_RSA_PUBLICKEY_FILEPATH, pubKey.getAbsolutePath());
        config.put(EncryptingSerializer.CRYPTO_VALUE_SERIALIZER, ByteArraySerializer.class.getName());
        config.put(DecryptingDeserializer.CRYPTO_VALUE_DESERIALIZER, ByteArrayDeserializer.class);
        config.put(DecryptingDeserializer.CRYPTO_HASH_METHOD, hashMethod);
        config.put(DecryptingDeserializer.CRYPTO_AES_KEY_LEN, String.valueOf(keylen));
        config.put(DecryptingDeserializer.CRYPTO_IGNORE_DECRYPT_FAILURES, "false");
        config.put(EncryptingSerializer.CRYPTO_NEW_KEY_MSG_INTERVAL, String.valueOf(msgInterval));

        final EncryptingSerializer<byte[]> serializer = new EncryptingSerializer<byte[]>();
        serializer.configure(config, false);

        final Deserializer<byte[]> deserializer = new DecryptingDeserializer<byte[]>();
        deserializer.configure(config, false);

        final Random rand = new Random(System.currentTimeMillis());
        for (int i = 0; i < 1000; i++) {
            final byte[] b = new byte[i];
            rand.nextBytes(b);
            Assert.assertArrayEquals(b, deserializer.deserialize(TOPIC, serializer.serialize(TOPIC, b)));
        }

        for (byte i = 0; i < Byte.MAX_VALUE; i++) {
            final byte[] b = new byte[i];
            Arrays.fill(b, i);
            Assert.assertArrayEquals(b, deserializer.deserialize(TOPIC, serializer.serialize(TOPIC, b)));
        }

        serializer.newKey();

        for (int i = 0; i < 100; i++) {
            final byte[] b = new byte[i];
            rand.nextBytes(b);
            Assert.assertArrayEquals(b, deserializer.deserialize(TOPIC, serializer.serialize(TOPIC, b)));
        }
        byte[] plainText = "The quick brown fox jumps over the lazy dog".getBytes("UTF-8");
        byte[] encryptedText = serializer.serialize(TOPIC, plainText);
        assertArrayEquals(SerdeCryptoBase.MAGIC_BYTES, Arrays.copyOfRange(encryptedText, 0, 2));

        assertArrayEquals(plainText, deserializer.deserialize(TOPIC, plainText));

        try {
            deserializer.deserialize(TOPIC, SerdeCryptoBase.MAGIC_BYTES);
            Assert.fail();
        } catch (Exception e) {
            //expected
        }

        config.put(DecryptingDeserializer.CRYPTO_IGNORE_DECRYPT_FAILURES, "true");
        deserializer.configure(config, false);

        try {
            assertArrayEquals(SerdeCryptoBase.MAGIC_BYTES, deserializer.deserialize(TOPIC, SerdeCryptoBase.MAGIC_BYTES));
        } catch (Exception e) {
            Assert.fail(e.toString());
        }
    }
}