/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.chimera.stream;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.security.SecureRandom;
import java.util.Properties;
import java.util.Random;

import com.intel.chimera.cipher.Cipher;
import com.intel.chimera.cipher.CipherTransformation;
import com.intel.chimera.cipher.JceCipher;
import com.intel.chimera.cipher.Openssl;
import com.intel.chimera.cipher.OpensslCipher;
import com.intel.chimera.utils.ReflectionUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public abstract class AbstractCryptoStreamTest {
  private static final Log LOG= LogFactory.getLog(AbstractCryptoStreamTest.class);

  private final int dataLen = 20000;
  private byte[] data = new byte[dataLen];
  private byte[] encData;
  private Properties props = new Properties();
  protected byte[] key = new byte[16];
  private byte[] iv = new byte[16];
  private int count = 10000;
  protected static int defaultBufferSize = 8192;
  protected static int smallBufferSize = 1024;

  private final String jceCipherClass = JceCipher.class.getName();
  private final String opensslCipherClass = OpensslCipher.class.getName();
  protected CipherTransformation transformation;

  public abstract void setUp() throws IOException;

  @Before
  public void before() throws IOException {
    Random random = new SecureRandom();
    random.nextBytes(data);
    random.nextBytes(key);
    random.nextBytes(iv);
    setUp();
    prepareData();
  }

  /** Test skip. */
  @Test(timeout=120000)
  public void testSkip() throws Exception {
    doSkipTest(jceCipherClass, false);
    doSkipTest(opensslCipherClass, false);

    doSkipTest(jceCipherClass, true);
    doSkipTest(opensslCipherClass, true);
  }

  /** Test byte buffer read with different buffer size. */
  @Test(timeout=120000)
  public void testByteBufferRead() throws Exception {
    doByteBufferRead(jceCipherClass, false);
    doByteBufferRead(opensslCipherClass, false);

    doByteBufferRead(jceCipherClass, true);
    doByteBufferRead(opensslCipherClass, true);
  }

  /** Test byte buffer write. */
  @Test(timeout=120000)
  public void testByteBufferWrite() throws Exception {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    doByteBufferWrite(jceCipherClass, baos, false);
    doByteBufferWrite(opensslCipherClass, baos, false);

    doByteBufferWrite(jceCipherClass, baos, true);
    doByteBufferWrite(opensslCipherClass, baos, true);
  }

  private void doSkipTest(String cipherClass, boolean withChannel) throws IOException {
    InputStream in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), defaultBufferSize, iv, withChannel);
    byte[] result = new byte[dataLen];
    int n1 = readAll(in, result, 0, dataLen / 3);

    long skipped = in.skip(dataLen / 3);
    int n2 = readAll(in, result, 0, dataLen);

    Assert.assertEquals(dataLen, n1 + skipped + n2);
    byte[] readData = new byte[n2];
    System.arraycopy(result, 0, readData, 0, n2);
    byte[] expectedData = new byte[n2];
    System.arraycopy(data, dataLen - n2, expectedData, 0, n2);
    Assert.assertArrayEquals(readData, expectedData);

    try {
      skipped = in.skip(-3);
      Assert.fail("Skip Negative length should fail.");
    } catch (IllegalArgumentException e) {
      Assert.assertTrue(e.getMessage().contains("Negative skip length"));
    }

    // Skip after EOF
    skipped = in.skip(3);
    Assert.assertEquals(skipped, 0);

    in.close();
  }

  private void doByteBufferRead(String cipherClass, boolean withChannel) throws Exception {
    // Default buffer size, initial buffer position is 0
    InputStream in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), defaultBufferSize, iv, withChannel);
    ByteBuffer buf = ByteBuffer.allocate(dataLen + 100);
    byteBufferReadCheck(in, buf, 0);
    in.close();

    // Default buffer size, initial buffer position is not 0
    in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), defaultBufferSize, iv, withChannel);
    buf.clear();
    byteBufferReadCheck(in, buf, 11);
    in.close();

    // Small buffer size, initial buffer position is 0
    in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), smallBufferSize, iv, withChannel);
    buf.clear();
    byteBufferReadCheck(in, buf, 0);
    in.close();

    // Small buffer size, initial buffer position is not 0
    in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), smallBufferSize, iv, withChannel);
    buf.clear();
    byteBufferReadCheck(in, buf, 11);
    in.close();

    // Direct buffer, default buffer size, initial buffer position is 0
    in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), defaultBufferSize, iv, withChannel);
    buf = ByteBuffer.allocateDirect(dataLen + 100);
    byteBufferReadCheck(in, buf, 0);
    in.close();

    // Direct buffer, default buffer size, initial buffer position is not 0
    in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), defaultBufferSize, iv, withChannel);
    buf.clear();
    byteBufferReadCheck(in, buf, 11);
    in.close();

    // Direct buffer, small buffer size, initial buffer position is 0
    in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), smallBufferSize, iv, withChannel);
    buf.clear();
    byteBufferReadCheck(in, buf, 0);
    in.close();

    // Direct buffer, small buffer size, initial buffer position is not 0
    in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), smallBufferSize, iv, withChannel);
    buf.clear();
    byteBufferReadCheck(in, buf, 11);
    in.close();
  }

  private void doByteBufferWrite(String cipherClass, ByteArrayOutputStream baos,
                                 boolean withChannel)
      throws Exception {
    baos.reset();
    CryptoOutputStream out =
        getCryptoOutputStream(baos, getCipher(cipherClass), defaultBufferSize,
            iv, withChannel);
    ByteBuffer buf = ByteBuffer.allocateDirect(dataLen / 2);
    buf.put(data, 0, dataLen / 2);
    buf.flip();
    int n1 = out.write(buf);

    buf.clear();
    buf.put(data, n1, dataLen / 3);
    buf.flip();
    int n2 = out.write(buf);

    buf.clear();
    buf.put(data, n1 + n2, dataLen - n1 - n2);
    buf.flip();
    int n3 = out.write(buf);

    Assert.assertEquals(dataLen, n1 + n2 + n3);

    out.flush();

    InputStream in = getCryptoInputStream(new ByteArrayInputStream(encData),
        getCipher(cipherClass), defaultBufferSize, iv, withChannel);
    buf = ByteBuffer.allocate(dataLen + 100);
    byteBufferReadCheck(in, buf, 0);
    in.close();
  }

  private void byteBufferReadCheck(InputStream in, ByteBuffer buf,
      int bufPos) throws Exception {
    buf.position(bufPos);
    int n = ((ReadableByteChannel) in).read(buf);
    Assert.assertEquals(bufPos + n, buf.position());
    byte[] readData = new byte[n];
    buf.rewind();
    buf.position(bufPos);
    buf.get(readData);
    byte[] expectedData = new byte[n];
    System.arraycopy(data, 0, expectedData, 0, n);
    Assert.assertArrayEquals(readData, expectedData);
  }

  private void prepareData() throws IOException {
    Cipher cipher = null;
    try {
      cipher = (Cipher)ReflectionUtils.newInstance(
          ReflectionUtils.getClassByName(jceCipherClass), props, transformation);
    } catch (ClassNotFoundException cnfe) {
      throw new IOException("Illegal crypto cipher!");
    }

    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    OutputStream out = new CryptoOutputStream(baos, cipher, defaultBufferSize, key, iv);
    out.write(data);
    out.flush();
    out.close();
    encData = baos.toByteArray();
  }

  protected CryptoInputStream getCryptoInputStream(ByteArrayInputStream bais,
                                                   Cipher cipher,
                                                   int bufferSize, byte[] iv,
                                                   boolean withChannel) throws
      IOException {
    if (withChannel) {
      return new CryptoInputStream(Channels.newChannel(bais), cipher,
          bufferSize, key, iv);
    } else {
      return new CryptoInputStream(bais, cipher, bufferSize, key, iv);
    }
  }

  protected CryptoOutputStream getCryptoOutputStream(ByteArrayOutputStream baos,
                                                   Cipher cipher,
                                                   int bufferSize, byte[] iv,
                                                   boolean withChannel) throws
      IOException {
    if (withChannel) {
      return new CryptoOutputStream(Channels.newChannel(baos), cipher,
          bufferSize, key, iv);
    } else {
      return new CryptoOutputStream(baos, cipher, bufferSize, key, iv);
    }
  }

  private int readAll(InputStream in, byte[] b, int offset, int len)
      throws IOException {
    int n = 0;
    int total = 0;
    while (n != -1) {
      total += n;
      if (total >= len) {
        break;
      }
      n = in.read(b, offset + total, len - total);
    }

    return total;
  }

  protected Cipher getCipher(String cipherClass) throws IOException {
    try {
      return (Cipher)ReflectionUtils.newInstance(
          ReflectionUtils.getClassByName(cipherClass), props, transformation);
    } catch (ClassNotFoundException cnfe) {
      throw new IOException("Illegal crypto cipher!");
    }
  }

  @Test
  public void testReadWrite() throws Exception {
    Assert.assertEquals(null, Openssl.getLoadingFailureReason());
    doReadWriteTest(0, jceCipherClass, jceCipherClass, iv);
    doReadWriteTest(0, opensslCipherClass, opensslCipherClass, iv);
    doReadWriteTest(count, jceCipherClass, jceCipherClass, iv);
    doReadWriteTest(count, opensslCipherClass, opensslCipherClass, iv);
    doReadWriteTest(count, jceCipherClass, opensslCipherClass, iv);
    doReadWriteTest(count, opensslCipherClass, jceCipherClass, iv);
    // Overflow test, IV: xx xx xx xx xx xx xx xx ff ff ff ff ff ff ff ff
    for(int i = 0; i < 8; i++) {
      iv[8 + i] = (byte) 0xff;
    }
    doReadWriteTest(count, jceCipherClass, jceCipherClass, iv);
    doReadWriteTest(count, opensslCipherClass, opensslCipherClass, iv);
    doReadWriteTest(count, jceCipherClass, opensslCipherClass, iv);
    doReadWriteTest(count, opensslCipherClass, jceCipherClass, iv);
  }

  private void doReadWriteTest(int count, String encCipherClass,
                               String decCipherClass, byte[] iv) throws IOException {
    doReadWriteTestForInputStream(count, encCipherClass, decCipherClass, iv);
    doReadWriteTestForReadableByteChannel(count, encCipherClass, decCipherClass,
        iv);
  }

  private void doReadWriteTestForInputStream(int count, String encCipherClass,
                                             String decCipherClass, byte[] iv) throws IOException {
    Cipher encCipher = getCipher(encCipherClass);
    LOG.debug("Created a cipher object of type: " + encCipherClass);

    // Generate data
    SecureRandom random = new SecureRandom();
    byte[] originalData = new byte[count];
    byte[] decryptedData = new byte[count];
    random.nextBytes(originalData);
    LOG.debug("Generated " + count + " records");

    // Encrypt data
    ByteArrayOutputStream encryptedData = new ByteArrayOutputStream();
    CryptoOutputStream out =
        getCryptoOutputStream(encryptedData, encCipher, defaultBufferSize, iv,
            false);
    out.write(originalData, 0, originalData.length);
    out.flush();
    out.close();
    LOG.debug("Finished encrypting data");

    Cipher decCipher = getCipher(decCipherClass);
    LOG.debug("Created a cipher object of type: " + decCipherClass);

    // Decrypt data
    CryptoInputStream in = getCryptoInputStream(
        new ByteArrayInputStream(encryptedData.toByteArray()), decCipher,
        defaultBufferSize, iv, false);

    // Check
    int remainingToRead = count;
    int offset = 0;
    while (remainingToRead > 0) {
      int n = in.read(decryptedData, offset, decryptedData.length - offset);
      if (n >=0) {
        remainingToRead -= n;
        offset += n;
      }
    }

    Assert.assertArrayEquals("originalData and decryptedData not equal",
        originalData, decryptedData);

    // Decrypt data byte-at-a-time
    in = getCryptoInputStream(
        new ByteArrayInputStream(encryptedData.toByteArray()), decCipher,
        defaultBufferSize, iv, false);

    // Check
    DataInputStream originalIn = new DataInputStream(new BufferedInputStream(new ByteArrayInputStream(originalData)));
    int expected;
    do {
      expected = originalIn.read();
      Assert.assertEquals("Decrypted stream read by byte does not match",
          expected, in.read());
    } while (expected != -1);

    LOG.debug("SUCCESS! Completed checking " + count + " records");
  }

  private void doReadWriteTestForReadableByteChannel(int count,
                                                     String encCipherClass,
                                                     String decCipherClass,
                                                     byte[] iv) throws IOException {
    Cipher encCipher = getCipher(encCipherClass);
    LOG.debug("Created a cipher object of type: " + encCipherClass);

    // Generate data
    SecureRandom random = new SecureRandom();
    byte[] originalData = new byte[count];
    byte[] decryptedData = new byte[count];
    random.nextBytes(originalData);
    LOG.debug("Generated " + count + " records");

    // Encrypt data
    ByteArrayOutputStream encryptedData = new ByteArrayOutputStream();
    CryptoOutputStream out =
        getCryptoOutputStream(encryptedData, encCipher, defaultBufferSize, iv,
            true);
    out.write(originalData, 0, originalData.length);
    out.flush();
    out.close();
    LOG.debug("Finished encrypting data");

    Cipher decCipher = getCipher(decCipherClass);
    LOG.debug("Created a cipher object of type: " + decCipherClass);

    // Decrypt data
    CryptoInputStream in = getCryptoInputStream(
        new ByteArrayInputStream(encryptedData.toByteArray()), decCipher,
        defaultBufferSize, iv, true);

    // Check
    int remainingToRead = count;
    int offset = 0;
    while (remainingToRead > 0) {
      int n = in.read(decryptedData, offset, decryptedData.length - offset);
      if (n >=0) {
        remainingToRead -= n;
        offset += n;
      }
    }

    Assert.assertArrayEquals("originalData and decryptedData not equal",
        originalData, decryptedData);

    // Decrypt data byte-at-a-time
    in = getCryptoInputStream(new ByteArrayInputStream(
        encryptedData.toByteArray()),decCipher,defaultBufferSize,iv,true);

    // Check
    DataInputStream originalIn = new DataInputStream(new BufferedInputStream(new ByteArrayInputStream(originalData)));
    int expected;
    do {
      expected = originalIn.read();
      Assert.assertEquals("Decrypted stream read by byte does not match",
          expected, in.read());
    } while (expected != -1);

    LOG.debug("SUCCESS! Completed checking " + count + " records");
  }
}