/*
 * Copyright (C) 2015-2018 SoftIndex LLC.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.datakernel.csp.eventloop;

import io.datakernel.bytebuf.ByteBuf;
import io.datakernel.bytebuf.ByteBufStrings;
import io.datakernel.csp.ChannelSupplier;
import io.datakernel.csp.binary.BinaryChannelSupplier;
import io.datakernel.csp.binary.ByteBufsDecoder;
import io.datakernel.net.AsyncTcpSocket;
import io.datakernel.net.AsyncTcpSocketNio;
import io.datakernel.net.AsyncTcpSocketSsl;
import io.datakernel.net.SimpleServer;
import io.datakernel.promise.Promise;
import io.datakernel.promise.Promises;
import io.datakernel.test.rules.ActivePromisesRule;
import io.datakernel.test.rules.ByteBufRule;
import io.datakernel.test.rules.EventloopRule;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Consumer;

import static io.datakernel.async.process.AsyncCloseable.CLOSE_EXCEPTION;
import static io.datakernel.bytebuf.ByteBufStrings.wrapAscii;
import static io.datakernel.csp.binary.BinaryChannelSupplier.UNEXPECTED_END_OF_STREAM_EXCEPTION;
import static io.datakernel.promise.TestUtils.await;
import static io.datakernel.promise.TestUtils.awaitException;
import static io.datakernel.test.TestUtils.assertComplete;
import static io.datakernel.test.TestUtils.getFreePort;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;

/**
 * Tests for {@link AsyncTcpSocketSsl}.
 * <p>
 * Most tests start a one-shot SSL {@link SimpleServer} on {@link #ADDRESS} and connect to it with an
 * {@link AsyncTcpSocketNio} client wrapped by {@link AsyncTcpSocketSsl#wrapClientSocket}. Both sides use
 * the same {@link SSLContext}, which is built from the test key store and trust store.
 */
public final class AsyncTcpSocketSslTest {
	private static final String KEYSTORE_PATH = "./src/test/resources/keystore.jks";
	private static final String KEYSTORE_PASS = "testtest";
	private static final String KEY_PASS = "testtest";

	private static final String TRUSTSTORE_PATH = "./src/test/resources/truststore.jks";
	private static final String TRUSTSTORE_PASS = "testtest";

	private static final String TEST_STRING = "Hello world";

	private static final InetSocketAddress ADDRESS = new InetSocketAddress("localhost", getFreePort());

	/** Decodes exactly {@code TEST_STRING.length()} bytes as an ASCII string. */
	private static final ByteBufsDecoder<String> DECODER = ByteBufsDecoder.ofFixedSize(TEST_STRING.length())
			.andThen(ByteBuf::asArray)
			.andThen(ByteBufStrings::decodeAscii);

	public static final int LARGE_STRING_SIZE = 10_000;
	public static final int SMALL_STRING_SIZE = 1000;

	/**
	 * Total number of bytes written by {@link #sendData}.
	 * That is one random string of {@link #LARGE_STRING_SIZE} characters, followed by
	 * {@link #SMALL_STRING_SIZE} copies of {@link #TEST_STRING}.
	 */
	private static final int LENGTH = LARGE_STRING_SIZE + TEST_STRING.length() * SMALL_STRING_SIZE;

	/** Decodes exactly {@link #LENGTH} bytes as an ASCII string. */
	private static final ByteBufsDecoder<String> DECODER_LARGE = ByteBufsDecoder.ofFixedSize(LENGTH)
			.andThen(ByteBuf::asArray)
			.andThen(ByteBufStrings::decodeAscii);

	@ClassRule
	public static final EventloopRule eventloopRule = new EventloopRule();

	@ClassRule
	public static final ByteBufRule byteBufRule = new ByteBufRule();

	@Rule
	public final ActivePromisesRule activePromisesRule = new ActivePromisesRule();

	// Executor handed to the client-side SSL wrapper.
	// NOTE(review): a new single-thread executor is created for every test and is never shut down here.
	private Executor executor;
	private SSLContext sslContext;
	// Everything written by sendData(), so the receiving side can compare against it.
	private StringBuilder sentData;

	@Before
	public void setUp() throws Exception {
		executor = Executors.newSingleThreadExecutor();
		sslContext = createSslContext();
		sentData = new StringBuilder();
	}

	/** The client writes {@link #TEST_STRING}; the server must decode exactly that string. */
	@Test
	public void testWrite() throws IOException {
		startServer(sslContext, sslSocket -> BinaryChannelSupplier.of(ChannelSupplier.ofSocket(sslSocket))
				.parse(DECODER)
				.whenComplete(sslSocket::close)
				.whenComplete(assertComplete(result -> assertEquals(TEST_STRING, result))));

		await(AsyncTcpSocketNio.connect(ADDRESS)
				.map(socket -> AsyncTcpSocketSsl.wrapClientSocket(socket, sslContext, executor))
				.then(sslSocket ->
						sslSocket.write(wrapAscii(TEST_STRING))
								.whenComplete(sslSocket::close)));
	}

	/** The server writes {@link #TEST_STRING}; the client must decode exactly that string. */
	@Test
	public void testRead() throws IOException {
		startServer(sslContext, sslSocket ->
				sslSocket.write(wrapAscii(TEST_STRING))
						.whenComplete(assertComplete()));

		String result = await(AsyncTcpSocketNio.connect(ADDRESS)
				.map(socket -> AsyncTcpSocketSsl.wrapClientSocket(socket, sslContext, executor))
				.then(sslSocket -> BinaryChannelSupplier.of(ChannelSupplier.ofSocket(sslSocket))
						.parse(DECODER)
						.whenComplete(sslSocket::close)));

		assertEquals(TEST_STRING, result);
	}

	/** The server echoes back what it decodes; the client must read its own message back. */
	@Test
	public void testLoopBack() throws IOException {
		startServer(sslContext, serverSsl -> BinaryChannelSupplier.of(ChannelSupplier.ofSocket(serverSsl))
				.parse(DECODER)
				.then(result -> serverSsl.write(wrapAscii(result)))
				.whenComplete(serverSsl::close)
				.whenComplete(assertComplete()));

		String result = await(AsyncTcpSocketNio.connect(ADDRESS)
				.map(socket -> AsyncTcpSocketSsl.wrapClientSocket(socket, sslContext, executor))
				.then(sslSocket ->
						sslSocket.write(wrapAscii(TEST_STRING))
								.then(() -> BinaryChannelSupplier.of(ChannelSupplier.ofSocket(sslSocket))
										.parse(DECODER))
								.whenComplete(sslSocket::close)));

		assertEquals(TEST_STRING, result);
	}

	/** Like {@link #testLoopBack}, but the message is sent in two halves with an empty buffer between them. */
	@Test
	public void testLoopBackWithEmptyBufs() throws IOException {
		int halfLength = TEST_STRING.length() / 2;
		String firstHalf = TEST_STRING.substring(0, halfLength);
		String secondHalf = TEST_STRING.substring(halfLength);
		startServer(sslContext, serverSsl -> BinaryChannelSupplier.of(ChannelSupplier.ofSocket(serverSsl))
				.parse(DECODER)
				.then(result -> serverSsl.write(wrapAscii(result)))
				.whenComplete(serverSsl::close)
				.whenComplete(assertComplete()));

		String result = await(AsyncTcpSocketNio.connect(ADDRESS)
				.map(socket -> AsyncTcpSocketSsl.wrapClientSocket(socket, sslContext, executor))
				.then(sslSocket ->
						sslSocket.write(wrapAscii(firstHalf))
								.then(() -> sslSocket.write(ByteBuf.empty()))
								.then(() -> sslSocket.write(wrapAscii(secondHalf)))
								.then(() -> BinaryChannelSupplier.of(ChannelSupplier.ofSocket(sslSocket))
										.parse(DECODER))
								.whenComplete(sslSocket::close)));

		assertEquals(TEST_STRING, result);
	}

	/** The client streams {@link #LENGTH} bytes; the server must decode exactly what was sent. */
	@Test
	public void sendsLargeAmountOfDataFromClientToServer() throws IOException {
		startServer(sslContext, serverSsl -> BinaryChannelSupplier.of(ChannelSupplier.ofSocket(serverSsl))
				.parse(DECODER_LARGE)
				.whenComplete(serverSsl::close)
				.whenComplete(assertComplete(result -> assertEquals(sentData.toString(), result))));

		// Chain the send with then(...) rather than whenResult(...), so that a failed send fails the awaited promise.
		await(AsyncTcpSocketNio.connect(ADDRESS)
				.map(socket -> AsyncTcpSocketSsl.wrapClientSocket(socket, sslContext, executor))
				.then(sslSocket ->
						sendData(sslSocket)
								.whenComplete(sslSocket::close)));
	}

	/** The server streams {@link #LENGTH} bytes; the client must decode exactly what was sent. */
	@Test
	public void sendsLargeAmountOfDataFromServerToClient() throws IOException {
		startServer(sslContext, serverSsl ->
				sendData(serverSsl)
						.whenComplete(serverSsl::close)
						.whenComplete(assertComplete()));

		String result = await(AsyncTcpSocketNio.connect(ADDRESS)
				.map(socket -> AsyncTcpSocketSsl.wrapClientSocket(socket, sslContext, executor))
				.then(sslSocket -> BinaryChannelSupplier.of(ChannelSupplier.ofSocket(sslSocket))
						.parse(DECODER_LARGE)
						.whenComplete(sslSocket::close)));

		assertEquals(sentData.toString(), result);
	}

	/**
	 * The server writes a partial message, closes, then writes again.
	 * The second write must fail with {@code CLOSE_EXCEPTION}. The client must see the stream end
	 * before a full message arrives.
	 */
	@Test
	public void testCloseAndOperationAfterClose() throws IOException {
		startServer(sslContext, socket ->
				socket.write(wrapAscii("He"))
						.whenComplete(socket::close)
						.then(() -> socket.write(wrapAscii("ello")))
						.whenComplete(($, e) -> assertSame(CLOSE_EXCEPTION, e)));

		Throwable e = awaitException(AsyncTcpSocketNio.connect(ADDRESS)
				.map(socket -> AsyncTcpSocketSsl.wrapClientSocket(socket, sslContext, executor))
				.then(sslSocket -> {
					BinaryChannelSupplier supplier = BinaryChannelSupplier.of(ChannelSupplier.ofSocket(sslSocket));
					return supplier.parse(DECODER)
							.whenException(supplier::closeEx);
				}));

		assertSame(UNEXPECTED_END_OF_STREAM_EXCEPTION, e);
	}

	/**
	 * A plain TCP peer accepts and then drops the connection.
	 * Meanwhile the client's own channel is closed before the SSL wrapper tries to write. The write
	 * must fail with {@code CLOSE_EXCEPTION}.
	 */
	@Test
	public void testPeerClosingDuringHandshake() throws IOException {
		ServerSocket listener = new ServerSocket(ADDRESS.getPort());
		Thread serverThread = new Thread(() -> {
			// try-with-resources releases the listener even when accept() fails.
			try (ServerSocket server = listener;
				 Socket ignored = server.accept()) {
				// Intentionally empty: the peer only needs to be accepted and then closed.
			} catch (IOException e) {
				throw new AssertionError(e);
			}
		});

		serverThread.start();

		Throwable exception = awaitException(AsyncTcpSocketNio.connect(ADDRESS)
				.whenResult(asyncTcpSocket -> {
					try {
						// noinspection ConstantConditions - Imitating a suddenly closed channel
						asyncTcpSocket.getSocketChannel().close();
					} catch (IOException e) {
						throw new AssertionError(e);
					}
				})
				.map(tcpSocket -> AsyncTcpSocketSsl.wrapClientSocket(tcpSocket, sslContext, executor))
				.then(socket -> socket.write(ByteBufStrings.wrapUtf8("hello"))));
		assertSame(CLOSE_EXCEPTION, exception);
	}

	/**
	 * Starts an SSL {@link SimpleServer} on {@link #ADDRESS} that accepts a single connection.
	 *
	 * @param sslContext context used for the server side of the TLS session
	 * @param logic      handler invoked with the accepted SSL socket
	 */
	static void startServer(SSLContext sslContext, Consumer<AsyncTcpSocket> logic) throws IOException {
		SimpleServer.create(logic)
				.withSslListenAddress(sslContext, Executors.newSingleThreadExecutor(), ADDRESS)
				.withAcceptOnce()
				.listen();
	}

	/**
	 * Builds a TLSv1.2 {@link SSLContext}.
	 * Key material comes from {@link #KEYSTORE_PATH} and trust anchors from {@link #TRUSTSTORE_PATH}.
	 */
	static SSLContext createSslContext() throws Exception {
		SSLContext instance = SSLContext.getInstance("TLSv1.2");

		KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
		kmf.init(loadKeyStore(KEYSTORE_PATH, KEYSTORE_PASS), KEY_PASS.toCharArray());

		TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
		tmf.init(loadKeyStore(TRUSTSTORE_PATH, TRUSTSTORE_PASS));

		instance.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
		return instance;
	}

	/** Loads a JKS key store from {@code path}, unlocking it with {@code password}. */
	private static KeyStore loadKeyStore(String path, String password) throws Exception {
		KeyStore keyStore = KeyStore.getInstance("JKS");
		try (InputStream input = new FileInputStream(new File(path))) {
			keyStore.load(input, password.toCharArray());
		}
		return keyStore;
	}

	/**
	 * Returns a random string of {@code size} characters, each one of {@code 'a'}, {@code 'b'} or {@code 'c'}.
	 * A non-positive {@code size} yields an empty string.
	 */
	public static String generateLargeString(int size) {
		// Clamp the capacity: a negative value would throw, whereas the loop below simply produces "".
		StringBuilder builder = new StringBuilder(Math.max(size, 0));
		Random random = new Random();
		for (int i = 0; i < size; i++) {
			builder.append((char) ('a' + random.nextInt(3)));
		}
		return builder.toString();
	}

	/**
	 * Writes one large random string, then {@link #SMALL_STRING_SIZE} copies of {@link #TEST_STRING}.
	 * Each small write is followed by {@code async()}. Everything written is also appended to
	 * {@link #sentData}.
	 *
	 * @return a promise that completes after the last write
	 */
	private Promise<?> sendData(AsyncTcpSocket socket) {
		String largeData = generateLargeString(LARGE_STRING_SIZE);
		ByteBuf largeBuf = wrapAscii(largeData);
		sentData.append(largeData);

		return socket.write(largeBuf)
				.then(() -> Promises.loop(SMALL_STRING_SIZE,
						i -> i != 0,
						i -> {
							sentData.append(TEST_STRING);
							return socket.write(wrapAscii(TEST_STRING))
									.async()
									.map($ -> i - 1);
						}));
	}
}