//++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// This file is a part of the 'coroutines' project.
// Copyright 2018 Elmar Sonnenschein, esoco GmbH, Flensburg, Germany
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	  http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
package de.esoco.coroutine.step.nio;

import de.esoco.coroutine.Continuation;
import de.esoco.coroutine.CoroutineException;

import java.io.IOException;

import java.net.SocketAddress;

import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.charset.StandardCharsets;

import java.util.function.BiPredicate;
import java.util.function.Function;


/********************************************************************
 * Implements asynchronous reading from a {@link AsynchronousSocketChannel}.
 * Received bytes are written into the {@link ByteBuffer} that is handed to the
 * step; a {@link BiPredicate} decides after each read whether the data is
 * complete. When receiving ends successfully the buffer is flipped so that it
 * can be read from directly.
 *
 * @author eso
 */
public class SocketReceive extends AsynchronousSocketStep
{
	//~ Instance fields --------------------------------------------------------

	private final BiPredicate<Integer, ByteBuffer> pCheckFinished;

	//~ Constructors -----------------------------------------------------------

	/***************************************
	 * Creates a new instance.
	 *
	 * @param fGetSocketAddress A function that provides the target socket
	 *                          address from the current continuation
	 * @param pCheckFinished    A predicate that checks whether receiving is
	 *                          complete by evaluating the number of bytes of
	 *                          the last read and the byte buffer after reading
	 */
	public SocketReceive(
		Function<Continuation<?>, SocketAddress> fGetSocketAddress,
		BiPredicate<Integer, ByteBuffer>		 pCheckFinished)
	{
		super(fGetSocketAddress);

		this.pCheckFinished = pCheckFinished;
	}

	//~ Static methods ---------------------------------------------------------

	/***************************************
	 * Returns a new predicate to be used with {@link #until(BiPredicate)} that
	 * checks whether a byte buffer contains the complete content of an HTTP
	 * response. The test is performed by calculating the full data size from
	 * the 'Content-Length' field in the response header and comparing it with
	 * the buffer position.
	 *
	 * <p>As long as the end of the response header has not been received the
	 * predicate evaluates to FALSE so that reading continues. If the header is
	 * complete but contains no (or an invalid) content length the predicate
	 * throws an {@link IllegalArgumentException}.</p>
	 *
	 * <p>The returned predicate is stateful (it caches the calculated length)
	 * and must therefore only be used for a single response.</p>
	 *
	 * @return A new predicate instance
	 */
	public static BiPredicate<Integer, ByteBuffer> contentFullyRead()
	{
		return new CheckContentLength();
	}

	/***************************************
	 * Suspends until data has been received from a network socket. The data
	 * will be stored in the input {@link ByteBuffer} of the step.
	 *
	 * <p>After the data has been received {@link ByteBuffer#flip()} will be
	 * invoked on the buffer so that it can be used directly for subsequent
	 * reading from it.</p>
	 *
	 * <p>The returned step only receives the next block of data that is sent by
	 * the remote socket and then continues the coroutine execution. If data
	 * should be read until a certain condition is met a derived step needs to
	 * be created with {@link #until(BiPredicate)}.</p>
	 *
	 * @param  fGetSocketAddress A function that provides the source socket
	 *                           address from the current continuation
	 *
	 * @return A new step instance
	 */
	public static SocketReceive receiveFrom(
		Function<Continuation<?>, SocketAddress> fGetSocketAddress)
	{
		// the default predicate accepts any successful read as complete
		return new SocketReceive(fGetSocketAddress, (r, bb) -> true);
	}

	/***************************************
	 * Variant of {@link #receiveFrom(Function)} that always uses the given
	 * socket address.
	 *
	 * @param  rSocketAddress The address of the socket to receive from
	 *
	 * @return A new step instance
	 *
	 * @see    #receiveFrom(Function)
	 */
	public static SocketReceive receiveFrom(SocketAddress rSocketAddress)
	{
		return receiveFrom(c -> rSocketAddress);
	}

	/***************************************
	 * Suspends until data has been received from a previously connected channel
	 * stored in the currently executed coroutine. If no such channel exists the
	 * execution will fail. This invocation is intended to be used for
	 * request-response communication where a receive is always preceded by a
	 * send operation.
	 *
	 * <p>The predicate argument is the same as for the {@link
	 * #until(BiPredicate)} method.</p>
	 *
	 * @param  pCheckFinished A predicate that checks whether the data has been
	 *                        received completely
	 *
	 * @return A new step instance
	 */
	public static SocketReceive receiveUntil(
		BiPredicate<Integer, ByteBuffer> pCheckFinished)
	{
		// no address: the step is created with an address function that
		// always yields NULL
		return receiveFrom((SocketAddress) null).until(pCheckFinished);
	}

	//~ Methods ----------------------------------------------------------------

	/***************************************
	 * Returns a new receive step instance that suspends until data has been
	 * received from a network socket and a certain condition on that data is
	 * met. If the end of the stream or the capacity of the buffer is reached
	 * before the condition is met the step fails with an exception.
	 *
	 * @param  pCheckFinished A predicate that checks whether the data has been
	 *                        received completely; it is invoked with the
	 *                        number of bytes of the last read and the buffer
	 *                        containing all data received so far
	 *
	 * @return A new step instance
	 */
	public SocketReceive until(BiPredicate<Integer, ByteBuffer> pCheckFinished)
	{
		return new SocketReceive(getSocketAddressFactory(), pCheckFinished);
	}

	/***************************************
	 * {@inheritDoc}
	 *
	 * <p>The finish predicate is only evaluated for non-negative byte counts.
	 * A count of -1 (end of stream) ends the receiving; any other negative
	 * count skips the predicate but still starts a read (NOTE(review):
	 * presumably the marker used by the base class for the first invocation
	 * - confirm against the superclass).</p>
	 */
	@Override
	protected boolean performAsyncOperation(
		int													nBytesReceived,
		AsynchronousSocketChannel							rChannel,
		ByteBuffer											rData,
		ChannelCallback<Integer, AsynchronousSocketChannel> rCallback)
		throws IOException
	{
		boolean bFinished = false;

		if (nBytesReceived >= 0)
		{
			bFinished = pCheckFinished.test(nBytesReceived, rData);
		}

		if (nBytesReceived != -1 && !bFinished && rData.hasRemaining())
		{
			// more data expected and space available: read the next chunk
			rChannel.read(rData, rData, rCallback);
		}
		else
		{
			// throws if receiving ended without the finish condition met
			checkErrors(rData, nBytesReceived, bFinished);
			rData.flip();
		}

		return bFinished;
	}

	/***************************************
	 * {@inheritDoc}
	 *
	 * <p>Reads repeatedly until the finish predicate is satisfied, the end of
	 * the stream is signaled, or the buffer is full. Like in the asynchronous
	 * variant the predicate is only evaluated for non-negative byte counts so
	 * that an end-of-stream signal is never reported to it as received
	 * data.</p>
	 */
	@Override
	protected void performBlockingOperation(
		AsynchronousSocketChannel aChannel,
		ByteBuffer				  rData) throws Exception
	{
		int     nReceived;
		boolean bFinished = false;

		do
		{
			nReceived = aChannel.read(rData).get();

			if (nReceived >= 0)
			{
				bFinished = pCheckFinished.test(nReceived, rData);
			}
		}
		while (nReceived != -1 && !bFinished && rData.hasRemaining());

		checkErrors(rData, nReceived, bFinished);

		rData.flip();
	}

	/***************************************
	 * Checks the received data and throws an exception on errors. An error
	 * exists if the finish condition is not met although the end of the stream
	 * has been reached or the buffer has no space left.
	 *
	 * @param  rData     The received data bytes
	 * @param  nReceived The number of bytes received on the last read
	 * @param  bFinished TRUE if the finish condition is met
	 *
	 * @throws IOException If an error is detected
	 */
	private void checkErrors(ByteBuffer rData, int nReceived, boolean bFinished)
		throws IOException
	{
		if (!bFinished)
		{
			if (nReceived == -1)
			{
				throw new IOException("Received data incomplete");
			}
			else if (!rData.hasRemaining())
			{
				throw new IOException("Buffer capacity exceeded");
			}
		}
	}

	//~ Inner Classes ----------------------------------------------------------

	/********************************************************************
	 * A predicate to check the Content-Length field in an HTTP header. The
	 * expected total size (header + separator + content) is calculated once
	 * when the complete header is available and then cached in the instance.
	 *
	 * @author eso
	 */
	static class CheckContentLength implements BiPredicate<Integer, ByteBuffer>
	{
		//~ Static fields/initializers -----------------------------------------

		private static final String LINE_END			  = "\r\n";
		private static final String HEADER_END			  = "\r\n\r\n";
		private static final String CONTENT_LENGTH_HEADER = "Content-Length:";

		//~ Instance fields ----------------------------------------------------

		// total number of bytes expected; -1 until the header has been parsed
		private long nFullLength = -1;

		//~ Methods ------------------------------------------------------------

		/***************************************
		 * Tests whether the buffer contains the complete HTTP response.
		 *
		 * @param  nReceived The number of bytes of the last read (not used)
		 * @param  rBuffer   The buffer containing all data received so far,
		 *                   with the position after the last received byte
		 *
		 * @return TRUE if the buffer position has reached the full response
		 *         size; FALSE if more data needs to be read, including the
		 *         case that the header is not yet complete
		 *
		 * @throws IllegalArgumentException If the complete header contains no
		 *                                  or an invalid content length
		 */
		@Override
		public boolean test(Integer nReceived, ByteBuffer rBuffer)
		{
			if (nFullLength == -1)
			{
				// ISO-8859-1 maps each byte to exactly one char so that
				// string indices are identical to byte offsets in the buffer
				String sData =
					StandardCharsets.ISO_8859_1.decode(
						(ByteBuffer) rBuffer.duplicate().flip()).toString();

				int nHeaderEnd = sData.indexOf(HEADER_END);

				if (nHeaderEnd == -1)
				{
					// header still incomplete (e.g. split over multiple
					// reads): signal that more data must be received
					return false;
				}

				long nContentLength =
					parseContentLength(sData.substring(0, nHeaderEnd));

				// only assign the field after successful parsing so that a
				// failure doesn't leave an inconsistent cached length
				nFullLength = nHeaderEnd + HEADER_END.length() + nContentLength;
			}

			return rBuffer.position() >= nFullLength;
		}

		/***************************************
		 * Extracts the value of the Content-Length field from an HTTP header.
		 * The field name is matched case-insensitively at the beginning of a
		 * header line and whitespace around the value is ignored.
		 *
		 * @param  sHeader The header text without the terminating empty line
		 *
		 * @return The content length
		 *
		 * @throws IllegalArgumentException If no content length field exists
		 *                                  or if its value is not a
		 *                                  non-negative number
		 */
		private static long parseContentLength(String sHeader)
		{
			int nNameLength = CONTENT_LENGTH_HEADER.length();

			for (String sLine : sHeader.split(LINE_END))
			{
				if (sLine.regionMatches(
						true,
						0,
						CONTENT_LENGTH_HEADER,
						0,
						nNameLength))
				{
					String sValue = sLine.substring(nNameLength).trim();

					// NumberFormatException is an IllegalArgumentException
					long nContentLength = Long.parseLong(sValue);

					if (nContentLength < 0)
					{
						throw new IllegalArgumentException(
							"Invalid content length: " + sValue);
					}

					return nContentLength;
				}
			}

			throw new IllegalArgumentException("No content length found");
		}
	}
}