package org.apache.flink.queryablestate.network;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.queryablestate.FutureUtils;
import org.apache.flink.queryablestate.network.messages.MessageBody;
import org.apache.flink.queryablestate.network.messages.MessageSerializer;
import org.apache.flink.queryablestate.network.stats.KvStateRequestStats;
import org.apache.flink.util.Preconditions;

import org.apache.flink.shaded.guava18.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.flink.shaded.netty4.io.netty.bootstrap.Bootstrap;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption;
import org.apache.flink.shaded.netty4.io.netty.channel.EventLoopGroup;
import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup;
import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel;
import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioSocketChannel;
import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

/**
 * The base class for every client in the queryable state module.
 * It uses pure Netty to send and receive messages of type {@link MessageBody}.
 *
 * @param <REQ> the type of request the client will send.
 * @param <RESP> the type of response the client expects to receive.
 */
public class Client<REQ extends MessageBody, RESP extends MessageBody> {

	private static final Logger LOG = LoggerFactory.getLogger(Client.class);

	/** The name of the client. Used for logging and stack traces.*/
	private final String clientName;

	/** Netty's Bootstrap. */
	private final Bootstrap bootstrap;

	/** The serializer to be used for (de-)serializing messages. */
	private final MessageSerializer<REQ, RESP> messageSerializer;

	/** Statistics tracker. */
	private final KvStateRequestStats stats;

	/** Established connections. */
	private final Map<InetSocketAddress, EstablishedConnection> establishedConnections = new ConcurrentHashMap<>();

	/** Pending connections. */
	private final Map<InetSocketAddress, PendingConnection> pendingConnections = new ConcurrentHashMap<>();

	/** Atomic shut down future. */
	private final AtomicReference<CompletableFuture<Void>> clientShutdownFuture = new AtomicReference<>(null);

	 * Creates a client with the specified number of event loop threads.
	 * @param clientName the name of the client.
	 * @param numEventLoopThreads number of event loop threads (minimum 1).
	 * @param serializer the serializer used to (de-)serialize messages.
	 * @param stats the statistics collector.
	public Client(
			final String clientName,
			final int numEventLoopThreads,
			final MessageSerializer<REQ, RESP> serializer,
			final KvStateRequestStats stats) {

		Preconditions.checkArgument(numEventLoopThreads >= 1,
				"Non-positive number of event loop threads.");

		this.clientName = Preconditions.checkNotNull(clientName);
		this.messageSerializer = Preconditions.checkNotNull(serializer);
		this.stats = Preconditions.checkNotNull(stats);

		final ThreadFactory threadFactory = new ThreadFactoryBuilder()
				.setNameFormat("Flink " + clientName + " Event Loop Thread %d")

		final EventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory);
		final ByteBufAllocator bufferPool = new NettyBufferPool(numEventLoopThreads);

		this.bootstrap = new Bootstrap()
				.option(ChannelOption.ALLOCATOR, bufferPool)
				.handler(new ChannelInitializer<SocketChannel>() {
					protected void initChannel(SocketChannel channel) throws Exception {
								.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
								.addLast(new ChunkedWriteHandler());

	public String getClientName() {
		return clientName;

	public CompletableFuture<RESP> sendRequest(final InetSocketAddress serverAddress, final REQ request) {
		if (clientShutdownFuture.get() != null) {
			return FutureUtils.getFailedFuture(new IllegalStateException(clientName + " is already shut down."));

		EstablishedConnection connection = establishedConnections.get(serverAddress);
		if (connection != null) {
			return connection.sendRequest(request);
		} else {
			PendingConnection pendingConnection = pendingConnections.get(serverAddress);
			if (pendingConnection != null) {
				// There was a race, use the existing pending connection.
				return pendingConnection.sendRequest(request);
			} else {
				// We try to connect to the server.
				PendingConnection pending = new PendingConnection(serverAddress, messageSerializer);
				PendingConnection previous = pendingConnections.putIfAbsent(serverAddress, pending);

				if (previous == null) {
					// OK, we are responsible to connect.
					bootstrap.connect(serverAddress.getAddress(), serverAddress.getPort()).addListener(pending);
					return pending.sendRequest(request);
				} else {
					// There was a race, use the existing pending connection.
					return previous.sendRequest(request);

	 * Shuts down the client and closes all connections.
	 * <p>After a call to this method, all returned futures will be failed.
	 * @return A {@link CompletableFuture} that will be completed when the shutdown process is done.
	public CompletableFuture<Void> shutdown() {
		final CompletableFuture<Void> newShutdownFuture = new CompletableFuture<>();
		if (clientShutdownFuture.compareAndSet(null, newShutdownFuture)) {

			final List<CompletableFuture<Void>> connectionFutures = new ArrayList<>();

			for (Map.Entry<InetSocketAddress, EstablishedConnection> conn : establishedConnections.entrySet()) {
				if (establishedConnections.remove(conn.getKey(), conn.getValue())) {

			for (Map.Entry<InetSocketAddress, PendingConnection> conn : pendingConnections.entrySet()) {
				if (pendingConnections.remove(conn.getKey()) != null) {

					connectionFutures.toArray(new CompletableFuture<?>[connectionFutures.size()])
			).whenComplete((result, throwable) -> {

				if (throwable != null) {
					LOG.warn("Problem while shutting down the connections at the {}: {}", clientName, throwable);

				if (bootstrap != null) {
					EventLoopGroup group = bootstrap.group();
					if (group != null && !group.isShutdown()) {
						group.shutdownGracefully(0L, 0L, TimeUnit.MILLISECONDS)
								.addListener(finished -> {
									if (finished.isSuccess()) {
									} else {
					} else {
				} else {

			return newShutdownFuture;
		return clientShutdownFuture.get();

	 * A pending connection that is in the process of connecting.
	private class PendingConnection implements ChannelFutureListener {

		/** Lock to guard the connect call, channel hand in, etc. */
		private final Object connectLock = new Object();

		/** Address of the server we are connecting to. */
		private final InetSocketAddress serverAddress;

		private final MessageSerializer<REQ, RESP> serializer;

		/** Queue of requests while connecting. */
		private final ArrayDeque<PendingRequest> queuedRequests = new ArrayDeque<>();

		/** The established connection after the connect succeeds. */
		private EstablishedConnection established;

		/** Atomic shut down future. */
		private final AtomicReference<CompletableFuture<Void>> connectionShutdownFuture = new AtomicReference<>(null);

		/** Failure cause if something goes wrong. */
		private Throwable failureCause;

		 * Creates a pending connection to the given server.
		 * @param serverAddress Address of the server to connect to.
		private PendingConnection(
				final InetSocketAddress serverAddress,
				final MessageSerializer<REQ, RESP> serializer) {
			this.serverAddress = serverAddress;
			this.serializer = serializer;

		public void operationComplete(ChannelFuture future) throws Exception {
			if (future.isSuccess()) {
			} else {

		 * Returns a future holding the serialized request result.
		 * <p>If the channel has been established, forward the call to the
		 * established channel, otherwise queue it for when the channel is
		 * handed in.
		 * @param request the request to be sent.
		 * @return Future holding the serialized result
		CompletableFuture<RESP> sendRequest(REQ request) {
			synchronized (connectLock) {
				if (failureCause != null) {
					return FutureUtils.getFailedFuture(failureCause);
				} else if (connectionShutdownFuture.get() != null) {
					return FutureUtils.getFailedFuture(new ClosedChannelException());
				} else {
					if (established != null) {
						return established.sendRequest(request);
					} else {
						// Queue this and handle when connected
						final PendingRequest pending = new PendingRequest(request);
						return pending;

		 * Hands in a channel after a successful connection.
		 * @param channel Channel to hand in
		private void handInChannel(Channel channel) {
			synchronized (connectLock) {
				if (connectionShutdownFuture.get() != null || failureCause != null) {
					// Close the channel and we are done. Any queued requests
					// are removed on the close/failure call and after that no
					// new ones can be enqueued.
				} else {
					established = new EstablishedConnection(serverAddress, serializer, channel);

					while (!queuedRequests.isEmpty()) {
						final PendingRequest pending = queuedRequests.poll();

								(response, throwable) -> {
									if (throwable != null) {
									} else {

					// Publish the channel for the general public
					establishedConnections.put(serverAddress, established);

					// Check shut down for possible race with shut down. We
					// don't want any lingering connections after shut down,
					// which can happen if we don't check this here.
					if (clientShutdownFuture.get() != null) {
						if (establishedConnections.remove(serverAddress, established)) {

		 * Close the connecting channel with a ClosedChannelException.
		private CompletableFuture<Void> close() {
			return close(new ClosedChannelException());

		 * Close the connecting channel with an Exception (can be {@code null})
		 * or forward to the established channel.
		private CompletableFuture<Void> close(Throwable cause) {
			CompletableFuture<Void> future = new CompletableFuture<>();
			if (connectionShutdownFuture.compareAndSet(null, future)) {
				synchronized (connectLock) {
					if (failureCause == null) {
						failureCause = cause;

					if (established != null) {
						established.close().whenComplete((result, throwable) -> {
							if (throwable != null) {
							} else {
					} else {
						PendingRequest pending;
						while ((pending = queuedRequests.poll()) != null) {
			return connectionShutdownFuture.get();

		public String toString() {
			synchronized (connectLock) {
				return "PendingConnection{" +
						"serverAddress=" + serverAddress +
						", queuedRequests=" + queuedRequests.size() +
						", established=" + (established != null) +
						", closed=" + (connectionShutdownFuture.get() != null) +

		 * A pending request queued while the channel is connecting.
		private final class PendingRequest extends CompletableFuture<RESP> {

			private final REQ request;

			private PendingRequest(REQ request) {
				this.request = request;

	 * An established connection that wraps the actual channel instance and is
	 * registered at the {@link ClientHandler} for callbacks.
	private class EstablishedConnection implements ClientHandlerCallback<RESP> {

		/** Address of the server we are connected to. */
		private final InetSocketAddress serverAddress;

		/** The actual TCP channel. */
		private final Channel channel;

		/** Pending requests keyed by request ID. */
		private final ConcurrentHashMap<Long, TimestampedCompletableFuture> pendingRequests = new ConcurrentHashMap<>();

		/** Current request number used to assign unique request IDs. */
		private final AtomicLong requestCount = new AtomicLong();

		/** Atomic shut down future. */
		private final AtomicReference<CompletableFuture<Void>> connectionShutdownFuture = new AtomicReference<>(null);

		 * Creates an established connection with the given channel.
		 * @param serverAddress Address of the server connected to
		 * @param channel The actual TCP channel
				final InetSocketAddress serverAddress,
				final MessageSerializer<REQ, RESP> serializer,
				final Channel channel) {

			this.serverAddress = Preconditions.checkNotNull(serverAddress);
			this.channel = Preconditions.checkNotNull(channel);

			// Add the client handler with the callback
					getClientName() + " Handler",
					new ClientHandler<>(clientName, serializer, this)


		 * Close the channel with a ClosedChannelException.
		CompletableFuture<Void> close() {
			return close(new ClosedChannelException());

		 * Close the channel with a cause.
		 * @param cause The cause to close the channel with.
		 * @return Channel close future
		private CompletableFuture<Void> close(final Throwable cause) {
			final CompletableFuture<Void> shutdownFuture = new CompletableFuture<>();

			if (connectionShutdownFuture.compareAndSet(null, shutdownFuture)) {
				channel.close().addListener(finished -> {
					for (long requestId : pendingRequests.keySet()) {
						TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
						if (pending != null && pending.completeExceptionally(cause)) {

					// when finishing, if netty successfully closes the channel, then the provided exception is used
					// as the reason for the closing. If there was something wrong at the netty side, then that exception
					// is prioritized over the provided one.
					if (finished.isSuccess()) {
					} else {
						LOG.warn("Something went wrong when trying to close connection due to : ", cause);

			// in case we had a race condition, return the winner of the race.
			return connectionShutdownFuture.get();

		 * Returns a future holding the serialized request result.
		 * @param request the request to be sent.
		 * @return Future holding the serialized result
		CompletableFuture<RESP> sendRequest(REQ request) {
			TimestampedCompletableFuture requestPromiseTs =
					new TimestampedCompletableFuture(System.nanoTime());
			try {
				final long requestId = requestCount.getAndIncrement();
				pendingRequests.put(requestId, requestPromiseTs);


				ByteBuf buf = MessageSerializer.serializeRequest(channel.alloc(), requestId, request);

				channel.writeAndFlush(buf).addListener((ChannelFutureListener) future -> {
					if (!future.isSuccess()) {
						// Fail promise if not failed to write
						TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
						if (pending != null && pending.completeExceptionally(future.cause())) {

				// Check for possible race. We don't want any lingering
				// promises after a failure, which can happen if we don't check
				// this here. Note that close is treated as a failure as well.
				CompletableFuture<Void> clShutdownFuture = clientShutdownFuture.get();
				if (clShutdownFuture != null) {
					TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
					if (pending != null) {
						clShutdownFuture.whenComplete((ignored, throwable) -> {
							if (throwable != null && pending.completeExceptionally(throwable)) {
							} else {
								// the shutdown future is always completed exceptionally so we should not arrive here.
								// but in any case, we complete the pending connection request exceptionally.
								pending.completeExceptionally(new ClosedChannelException());
			} catch (Throwable t) {

			return requestPromiseTs;

		public void onRequestResult(long requestId, RESP response) {
			TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
			if (pending != null && !pending.isDone()) {
				long durationMillis = (System.nanoTime() - pending.getTimestamp()) / 1_000_000L;

		public void onRequestFailure(long requestId, Throwable cause) {
			TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
			if (pending != null && !pending.isDone()) {

		public void onFailure(Throwable cause) {
			close(cause).handle((cancelled, ignored) -> establishedConnections.remove(serverAddress, this));

		public String toString() {
			return "EstablishedConnection{" +
					"serverAddress=" + serverAddress +
					", channel=" + channel +
					", pendingRequests=" + pendingRequests.size() +
					", requestCount=" + requestCount +

		 * Pair of promise and a timestamp.
		private class TimestampedCompletableFuture extends CompletableFuture<RESP> {

			private final long timestampInNanos;

			TimestampedCompletableFuture(long timestampInNanos) {
				this.timestampInNanos = timestampInNanos;

			public long getTimestamp() {
				return timestampInNanos;

	public boolean isEventGroupShutdown() {
		return bootstrap == null || bootstrap.group().isTerminated();