package prj.sslfacade;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CharsetEncoder;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Semaphore;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import org.assertj.core.api.Assertions;
import org.junit.Before;
import org.junit.Test;

/**
 * This test checks whether communication over the SSLFacade works. It was originally written mainly to check the library itself.
 */
public class SSLFacadeTest
{

  public static final String SERVER_TAG = "server";
  public static final String CLIENT_TAG = "client";

  public static final String JKS_FILE_PASSWORD = "123456";
  public static final String JKS_FILE = "src/test/resources/test.jks";

  public static final String END_OF_SESSION = "END_OF_SESSION";
  public static final String END_OF_HANDSHAKE = "END_OF_HANDSHAKE";
  public static final String HELLO_FROM_CLIENT_1 = "Hello from client 1";
  public static final String HELLO_FROM_SERVER_1 = "Hello from server 1";
  public static final String HELLO_FROM_CLIENT_2 = "Hello from client 2";
  public static final String HELLO_FROM_SERVER_2 = "Hello from server 2";
  public static final String HELLO_FROM_CLIENT_3 = "Hello from client 3";

  private final ITaskHandler taskHandler = new DefaultTaskHandler();

  private final CharsetEncoder encoder = Charset.forName("US-ASCII").newEncoder();
  private final CharsetDecoder decoder = Charset.forName("US-ASCII").newDecoder();

  // Plain-text payloads; wrapped fresh for every test in setUp().
  private CharBuffer clientIn1;
  private CharBuffer serverIn1;
  private CharBuffer clientIn2;
  private CharBuffer serverIn2;
  private CharBuffer clientIn3;

  // Each side records the events it observes (handshake end, received strings, session end)
  // and releases its semaphore once per event, so tests can wait for a given number of events.
  private List<String> clientNotifications;
  private List<String> serverNotifications;
  private Semaphore sslClientSem;
  private Semaphore sslServerSem;
  private ISSLFacade sslClient;
  private ISSLFacade sslServer;
  private SSLListener clientListener;
  private SSLListener serverListener;
  private SSLContext sslCtx;

  /**
   * Listener that forwards wrapped (encrypted) bytes produced by one facade to its peer's
   * {@code decrypt}, and records every decoded plain-text message it receives.
   */
  class SSLListener implements ISSLListener
  {

    private final ISSLFacade sslPeer;
    private final String who;
    private final List<String> notifications;
    private final Semaphore sem;
    // Staging area for wrapped bytes not yet handed to the peer.
    private final ByteBuffer buffer = ByteBuffer.allocate(1024 * 5);
    private boolean autoflush = true;

    /**
     * @param who           tag used in log output
     * @param ssl           the peer facade that receives this side's wrapped bytes
     * @param notifications list to which decoded plain-text messages are appended
     * @param sem           semaphore released once per decoded message
     */
    public SSLListener(final String who, final ISSLFacade ssl, List<String> notifications, final Semaphore sem)
    {
      this.sslPeer = ssl;
      this.who = who;
      this.notifications = notifications;
      this.sem = sem;
    }

    /**
     * When {@code true} (the default), wrapped bytes are passed to the peer as soon as they
     * arrive; when {@code false} they accumulate until {@link #flush()} is called.
     */
    public void setAutoflush(boolean autoflush)
    {
      this.autoflush = autoflush;
    }

    @Override
    public void onWrappedData(ByteBuffer wrappedBytes)
    {
      try {
        log(who + " onWrappedData: pass data " + wrappedBytes + " to buffer " + buffer);
        buffer.put(wrappedBytes);
        if (autoflush) {
          flush();
        }
        log(who + " onWrappedData: data decrypted " + wrappedBytes + "in buffer " + buffer);
      } catch (SSLException ex) {
        log(who + " onWrappedData: Error while sending data to peer; " + ex);
      }
    }

    @Override
    public void onPlainData(ByteBuffer plainBytes)
    {
      log(who + ": received plain data: " + plainBytes);
      try {
        CharBuffer decodedString = decoder.decode(plainBytes);
        log(who + ": String received: " + decodedString);
        notifications.add(decodedString.toString());
        sem.release();
      } catch (CharacterCodingException ex) {
        log(who + ": !ERROR! could not decode data received from peer");
      }
    }

    /**
     * Hands every staged wrapped byte to the peer in a single {@code decrypt} call and empties
     * the staging buffer.
     *
     * @throws SSLException if the peer fails to decrypt the data
     */
    public void flush() throws SSLException
    {
      buffer.flip();
      ByteBuffer bb = ByteBuffer.allocate(buffer.capacity());
      bb.put(buffer);
      buffer.compact();

      bb.flip();
      sslPeer.decrypt(bb);
    }
  }

  private static void log(final String message)
  {
    System.out.println("[SSLFacadeTest]: " + message);
  }

  /**
   * Loads a JKS key store from {@code file}, always closing the underlying stream.
   */
  private static KeyStore loadKeyStore(final String file, final char[] passphrase)
          throws IOException, KeyStoreException, NoSuchAlgorithmException, CertificateException
  {
    KeyStore store = KeyStore.getInstance("JKS");
    FileInputStream in = new FileInputStream(file);
    try {
      store.load(in, passphrase);
    } finally {
      in.close();
    }
    return store;
  }

  /**
   * Creates a facade on the shared context and attaches a handshake listener that records
   * {@link #END_OF_HANDSHAKE} in {@code notifications} and releases {@code sem}.
   */
  public ISSLFacade createSSL(final String who, boolean client, final List<String> notifications, final Semaphore sem)
  {
    ISSLFacade ssl = new SSLFacade(sslCtx, client, false, taskHandler);
    attachHandshakeListener(who, ssl, notifications, sem);

    return ssl;
  }

  /**
   * Registers a handshake-completed listener on {@code ssl} that appends
   * {@link #END_OF_HANDSHAKE} to {@code notifications} and releases {@code sem}.
   */
  public void attachHandshakeListener(final String who, final ISSLFacade ssl, final List<String> notifications, final Semaphore sem)
  {
    ssl.setHandshakeCompletedListener(new IHandshakeCompletedListener()
    {
      @Override
      public void onComplete()
      {
        log(who + ": Handshake completed.");
        notifications.add(END_OF_HANDSHAKE);
        sem.release();
        log(who + ": semaphore released " + sem);
      }
    });
  }

  private SSLListener createListener(final String who, final ISSLFacade sslPeer, final List<String> notifications, final Semaphore sem)
  {
    return new SSLListener(who, sslPeer, notifications, sem);
  }

  /**
   * Registers a close listener on {@code ssl} that appends {@link #END_OF_SESSION} to
   * {@code notifications} and releases {@code sem}.
   */
  private void attachSessionCloseListener(final String who, final ISSLFacade ssl, final List<String> notifications, final Semaphore sem)
  {
    ssl.setCloseListener(new ISessionClosedListener()
    {
      @Override
      public void onSessionClosed()
      {
        log(who + ": peer closed the session. Post notification on sem : " + sem);
        notifications.add(END_OF_SESSION);
        sem.release();
        log(who + ": peer closed the session. Sem notified : " + sem);
      }
    });
  }

  @Before
  public void setUp() throws IOException, NoSuchAlgorithmException, KeyStoreException, CertificateException, UnrecoverableKeyException, KeyManagementException
  {
    char[] passphrase = JKS_FILE_PASSWORD.toCharArray();

    // The same JKS file serves as both the key store and the trust store.
    KeyStore ks = loadKeyStore(JKS_FILE, passphrase);
    KeyStore ts = loadKeyStore(JKS_FILE, passphrase);

    KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
    kmf.init(ks, passphrase);

    TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
    tmf.init(ts);

    sslCtx = SSLContext.getInstance("TLS");
    sslCtx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);

    clientNotifications = new LinkedList<String>();
    serverNotifications = new LinkedList<String>();

    sslClientSem = new Semaphore(0);
    sslServerSem = new Semaphore(0);

    sslClient = createSSL(CLIENT_TAG, true, clientNotifications, sslClientSem);
    sslServer = createSSL(SERVER_TAG, false, serverNotifications, sslServerSem);

    log("== Init SSL listeners");
    // Each listener forwards its own side's wrapped output to the opposite facade.
    clientListener = createListener(CLIENT_TAG, sslServer, clientNotifications, sslClientSem);
    serverListener = createListener(SERVER_TAG, sslClient, serverNotifications, sslServerSem);
    sslClient.setSSLListener(clientListener);
    sslServer.setSSLListener(serverListener);

    clientIn1 = CharBuffer.wrap(HELLO_FROM_CLIENT_1);
    serverIn1 = CharBuffer.wrap(HELLO_FROM_SERVER_1);
    clientIn2 = CharBuffer.wrap(HELLO_FROM_CLIENT_2);
    serverIn2 = CharBuffer.wrap(HELLO_FROM_SERVER_2);
    clientIn3 = CharBuffer.wrap(HELLO_FROM_CLIENT_3);
  }

  /**
   * Starts the handshake on both sides and waits until each side reports completion.
   */
  private void performHandshake() throws IOException, InterruptedException
  {
    log("== Client started handshake");
    sslClient.beginHandshake();
    log("== Server started handshake");
    sslServer.beginHandshake();

    log("== Client waits untill handshake is done on " + sslClientSem);
    sslClientSem.acquire();

    log("== Server waits untill handshake is done on " + sslServerSem);
    sslServerSem.acquire();
  }

  /**
   * Encrypts the three client messages with autoflush disabled, so all their wrapped bytes
   * reach the server in one {@code decrypt} call, then waits until the server has decoded
   * all three.
   */
  private void sendClientMessagesInOneBatch() throws IOException, InterruptedException
  {
    log("== Sending messages");
    clientListener.setAutoflush(false);
    sslClient.encrypt(encoder.encode(clientIn1));
    sslClient.encrypt(encoder.encode(clientIn2));
    sslClient.encrypt(encoder.encode(clientIn3));
    clientListener.flush(); // check what happens if all encoded data is passed in one message

    // Restore autoflush so data wrapped during close() is delivered to the server.
    clientListener.setAutoflush(true);

    log("== Wait untill all messages arrived");
    sslServerSem.acquire(3);
  }

  /**
   * Attaches close listeners to both sides, closes the client, and waits until the server
   * reports the end of the session.
   */
  private void closeFromClientSide() throws IOException, InterruptedException
  {
    log("== Close connection on client side");
    attachSessionCloseListener(CLIENT_TAG, sslClient, clientNotifications, sslClientSem);
    attachSessionCloseListener(SERVER_TAG, sslServer, serverNotifications, sslServerSem);
    sslClient.close();

    // This waits on the server semaphore, so that is the one logged.
    log("== Wait server has received end of session on sem " + sslServerSem);
    sslServerSem.acquire();
  }

  /**
   * Handshake, full-duplex exchange, alternating messages, then a client-initiated close.
   *
   * @throws java.io.IOException if encryption, decryption or closing fails
   * @throws java.lang.InterruptedException if interrupted while waiting for a peer event
   */
  @Test
  public void check_simpleCommunicationScenario() throws SSLException, CharacterCodingException, InterruptedException, IOException
  {
    // when
    performHandshake();

    log("== Sending first message (full duplex)");
    sslClient.encrypt(encoder.encode(clientIn1));
    sslServer.encrypt(encoder.encode(serverIn1));

    log("== Wait untill the first message arrived");
    sslClientSem.acquire();
    sslServerSem.acquire();

    log("== Sending second message to server");
    sslClient.encrypt(encoder.encode(clientIn2));
    sslServerSem.acquire();

    log("== Sending second message to client");
    sslServer.encrypt(encoder.encode(serverIn2));
    sslClientSem.acquire();

    closeFromClientSide();

    //then
    Assertions.assertThat(clientNotifications)
            .hasSize(4)
            .containsExactly(END_OF_HANDSHAKE, HELLO_FROM_SERVER_1, HELLO_FROM_SERVER_2, END_OF_SESSION);

    Assertions.assertThat(serverNotifications)
            .hasSize(4)
            .containsExactly(END_OF_HANDSHAKE, HELLO_FROM_CLIENT_1, HELLO_FROM_CLIENT_2, END_OF_SESSION);
  }

  /**
   * Several client messages whose wrapped bytes are delivered to the server in one chunk.
   *
   * @throws java.io.IOException if encryption, decryption or closing fails
   * @throws java.lang.InterruptedException if interrupted while waiting for a peer event
   */
  @Test
  public void shall_transferSeriesOfMessages() throws SSLException, CharacterCodingException, InterruptedException, IOException
  {
    // when
    performHandshake();
    sendClientMessagesInOneBatch();
    closeFromClientSide();

    //then
    Assertions.assertThat(clientNotifications)
            .containsExactly(END_OF_HANDSHAKE, END_OF_SESSION);

    Assertions.assertThat(serverNotifications)
            .containsExactly(END_OF_HANDSHAKE, HELLO_FROM_CLIENT_1, HELLO_FROM_CLIENT_2, HELLO_FROM_CLIENT_3, END_OF_SESSION);
  }

  /**
   * NOTE(review): this scenario is currently identical to
   * {@link #shall_transferSeriesOfMessages()}; a genuinely large stream is not exercised yet.
   *
   * @throws java.io.IOException if encryption, decryption or closing fails
   * @throws java.lang.InterruptedException if interrupted while waiting for a peer event
   */
  @Test
  public void shall_transferBigStreamOfMessages() throws SSLException, CharacterCodingException, InterruptedException, IOException
  {
    // when
    performHandshake();
    sendClientMessagesInOneBatch();
    closeFromClientSide();

    //then
    Assertions.assertThat(clientNotifications)
            .containsExactly(END_OF_HANDSHAKE, END_OF_SESSION);

    Assertions.assertThat(serverNotifications)
            .containsExactly(END_OF_HANDSHAKE, HELLO_FROM_CLIENT_1, HELLO_FROM_CLIENT_2, HELLO_FROM_CLIENT_3, END_OF_SESSION);
  }

  @Test
  public void check_clientModeSet()
  {
    // given
    boolean isClient = true;

    //when
    ISSLFacade facade = new SSLFacade(sslCtx, isClient, false, taskHandler);

    //then
    Assertions.assertThat(facade.isClientMode()).isTrue();
  }

  @Test
  public void check_serverModeSet()
  {
    // given
    boolean isClient = false;

    //when
    ISSLFacade facade = new SSLFacade(sslCtx, isClient, false, taskHandler);

    //then
    Assertions.assertThat(facade.isClientMode()).isFalse();
  }
}