/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.queryablestate.network;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.queryablestate.FutureUtils;
import org.apache.flink.queryablestate.network.messages.MessageBody;
import org.apache.flink.queryablestate.network.messages.MessageSerializer;
import org.apache.flink.queryablestate.network.stats.KvStateRequestStats;
import org.apache.flink.util.Preconditions;

import org.apache.flink.shaded.guava18.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.flink.shaded.netty4.io.netty.bootstrap.Bootstrap;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator;
import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption;
import org.apache.flink.shaded.netty4.io.netty.channel.EventLoopGroup;
import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup;
import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel;
import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioSocketChannel;
import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import org.apache.flink.shaded.netty4.io.netty.handler.stream.ChunkedWriteHandler;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

/**
 * The base class for every client in the queryable state module.
 * It is using pure netty to send and receive messages of type {@link MessageBody}.
 *
 * <p>Connections are tracked per server address in two maps: one for connections
 * that are still being set up and one for connections that have been established.
 *
 * @param <REQ> the type of request the client will send.
 * @param <RESP> the type of response the client expects to receive.
 */
@Internal
public class Client<REQ extends MessageBody, RESP extends MessageBody> {

    private static final Logger LOG = LoggerFactory.getLogger(Client.class);

    /** The name of the client. Used for logging and stack traces.*/
    private final String clientName;

    /** Netty's Bootstrap. */
    private final Bootstrap bootstrap;

    /** The serializer to be used for (de-)serializing messages. */
    private final MessageSerializer<REQ, RESP> messageSerializer;

    /** Statistics tracker. */
    private final KvStateRequestStats stats;

    /** Established connections. */
    private final Map<InetSocketAddress, EstablishedConnection> establishedConnections = new ConcurrentHashMap<>();

    /** Pending connections. */
    private final Map<InetSocketAddress, PendingConnection> pendingConnections = new ConcurrentHashMap<>();

    /** Atomic shut down future. */
    private final AtomicReference<CompletableFuture<Void>> clientShutdownFuture = new AtomicReference<>(null);

    /**
     * Creates a client with the specified number of event loop threads.
     *
     * @param clientName the name of the client.
     * @param numEventLoopThreads number of event loop threads (minimum 1).
     * @param serializer the serializer used to (de-)serialize messages.
     * @param stats the statistics collector.
     */
    public Client(
            final String clientName,
            final int numEventLoopThreads,
            final MessageSerializer<REQ, RESP> serializer,
            final KvStateRequestStats stats) {

        Preconditions.checkArgument(numEventLoopThreads >= 1,
                "Non-positive number of event loop threads.");

        this.clientName = Preconditions.checkNotNull(clientName);
        this.messageSerializer = Preconditions.checkNotNull(serializer);
        this.stats = Preconditions.checkNotNull(stats);

        final ThreadFactory threadFactory = new ThreadFactoryBuilder()
                .setDaemon(true)
                .setNameFormat("Flink " + clientName + " Event Loop Thread %d")
                .build();

        final EventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory);
        final ByteBufAllocator bufferPool = new NettyBufferPool(numEventLoopThreads);

        this.bootstrap = new Bootstrap()
                .group(nioGroup)
                .channel(NioSocketChannel.class)
                .option(ChannelOption.ALLOCATOR, bufferPool)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel channel) throws Exception {
                        channel.pipeline()
                                .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
                                .addLast(new ChunkedWriteHandler());
                    }
                });
    }

    /**
     * Returns the name this client was created with.
     *
     * @return the client name.
     */
    public String getClientName() {
        return clientName;
    }

    /**
     * Sends the given request to the given server, re-using an established or
     * pending connection to that address if one is registered, and starting a
     * new connect attempt otherwise.
     *
     * @param serverAddress the address of the server to send the request to.
     * @param request the request to be sent.
     * @return A future for the response. It is already failed with an
     *         {@link IllegalStateException} if the client has been shut down.
     */
    public CompletableFuture<RESP> sendRequest(final InetSocketAddress serverAddress, final REQ request) {
        if (clientShutdownFuture.get() != null) {
            return FutureUtils.getFailedFuture(new IllegalStateException(clientName + " is already shut down."));
        }

        EstablishedConnection connection = establishedConnections.get(serverAddress);
        if (connection != null) {
            return connection.sendRequest(request);
        } else {
            PendingConnection pendingConnection = pendingConnections.get(serverAddress);
            if (pendingConnection != null) {
                // There was a race, use the existing pending connection.
                return pendingConnection.sendRequest(request);
            } else {
                // We try to connect to the server.
                PendingConnection pending = new PendingConnection(serverAddress, messageSerializer);
                PendingConnection previous = pendingConnections.putIfAbsent(serverAddress, pending);

                if (previous == null) {
                    // OK, we are responsible to connect.
                    bootstrap.connect(serverAddress.getAddress(), serverAddress.getPort()).addListener(pending);
                    return pending.sendRequest(request);
                } else {
                    // There was a race, use the existing pending connection.
                    return previous.sendRequest(request);
                }
            }
        }
    }

    /**
     * Shuts down the client and closes all connections.
     *
     * <p>After a call to this method, all returned futures will be failed.
     *
     * <p>Only the first call performs the shutdown; later calls return the
     * future created by the first one.
     *
     * @return A {@link CompletableFuture} that will be completed when the shutdown process is done.
     */
    public CompletableFuture<Void> shutdown() {
        final CompletableFuture<Void> newShutdownFuture = new CompletableFuture<>();
        if (clientShutdownFuture.compareAndSet(null, newShutdownFuture)) {

            final List<CompletableFuture<Void>> connectionFutures = new ArrayList<>();

            for (Map.Entry<InetSocketAddress, EstablishedConnection> conn : establishedConnections.entrySet()) {
                if (establishedConnections.remove(conn.getKey(), conn.getValue())) {
                    connectionFutures.add(conn.getValue().close());
                }
            }

            for (Map.Entry<InetSocketAddress, PendingConnection> conn : pendingConnections.entrySet()) {
                // Remove by key AND value, like the loop above, so that we only close
                // the exact connection instance we just unregistered.
                if (pendingConnections.remove(conn.getKey(), conn.getValue())) {
                    connectionFutures.add(conn.getValue().close());
                }
            }

            CompletableFuture.allOf(
                    connectionFutures.toArray(new CompletableFuture<?>[connectionFutures.size()])
            ).whenComplete((result, throwable) -> {

                if (throwable != null) {
                    // The throwable is passed as the trailing exception argument (no
                    // placeholder for it) so that the logger records it as the cause.
                    LOG.warn("Problem while shutting down the connections at the {}", clientName, throwable);
                }

                if (bootstrap != null) {
                    EventLoopGroup group = bootstrap.group();
                    if (group != null && !group.isShutdown()) {
                        group.shutdownGracefully(0L, 0L, TimeUnit.MILLISECONDS)
                                .addListener(finished -> {
                                    if (finished.isSuccess()) {
                                        newShutdownFuture.complete(null);
                                    } else {
                                        newShutdownFuture.completeExceptionally(finished.cause());
                                    }
                                });
                    } else {
                        newShutdownFuture.complete(null);
                    }
                } else {
                    newShutdownFuture.complete(null);
                }
            });

            return newShutdownFuture;
        }
        return clientShutdownFuture.get();
    }

    /**
     * A pending connection that is in the process of connecting.
     */
    private class PendingConnection implements ChannelFutureListener {

        /** Lock to guard the connect call, channel hand in, etc. */
        private final Object connectLock = new Object();

        /** Address of the server we are connecting to. */
        private final InetSocketAddress serverAddress;

        /** Serializer handed to the established connection once connected. */
        private final MessageSerializer<REQ, RESP> serializer;

        /** Queue of requests while connecting. */
        private final ArrayDeque<PendingRequest> queuedRequests = new ArrayDeque<>();

        /** The established connection after the connect succeeds. */
        private EstablishedConnection established;

        /** Atomic shut down future. */
        private final AtomicReference<CompletableFuture<Void>> connectionShutdownFuture = new AtomicReference<>(null);

        /** Failure cause if something goes wrong. */
        private Throwable failureCause;

        /**
         * Creates a pending connection to the given server.
         *
         * @param serverAddress Address of the server to connect to.
         * @param serializer Serializer to use for the connection once it is established.
         */
        private PendingConnection(
                final InetSocketAddress serverAddress,
                final MessageSerializer<REQ, RESP> serializer) {
            this.serverAddress = serverAddress;
            this.serializer = serializer;
        }

        /**
         * Callback for the connect attempt: hands in the channel on success,
         * otherwise unregisters and closes this pending connection.
         *
         * @param future the completed connect future.
         */
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            if (future.isSuccess()) {
                handInChannel(future.channel());
            } else {
                // Unregister this failed attempt first. Otherwise it would stay in the
                // map and every later request to this address would be answered with
                // the stale failure instead of triggering a new connect attempt.
                pendingConnections.remove(serverAddress, this);
                close(future.cause());
            }
        }

        /**
         * Returns a future holding the serialized request result.
         *
         * <p>If the channel has been established, forward the call to the
         * established channel, otherwise queue it for when the channel is
         * handed in.
         *
         * @param request the request to be sent.
         * @return Future holding the serialized result
         */
        CompletableFuture<RESP> sendRequest(REQ request) {
            synchronized (connectLock) {
                if (failureCause != null) {
                    return FutureUtils.getFailedFuture(failureCause);
                } else if (connectionShutdownFuture.get() != null) {
                    return FutureUtils.getFailedFuture(new ClosedChannelException());
                } else {
                    if (established != null) {
                        return established.sendRequest(request);
                    } else {
                        // Queue this and handle when connected
                        final PendingRequest pending = new PendingRequest(request);
                        queuedRequests.add(pending);
                        return pending;
                    }
                }
            }
        }

        /**
         * Hands in a channel after a successful connection.
         *
         * @param channel Channel to hand in
         */
        private void handInChannel(Channel channel) {
            synchronized (connectLock) {
                if (connectionShutdownFuture.get() != null || failureCause != null) {
                    // Close the channel and we are done. Any queued requests
                    // are removed on the close/failure call and after that no
                    // new ones can be enqueued.
                    channel.close();
                } else {
                    established = new EstablishedConnection(serverAddress, serializer, channel);

                    while (!queuedRequests.isEmpty()) {
                        final PendingRequest pending = queuedRequests.poll();

                        established.sendRequest(pending.request).whenComplete(
                                (response, throwable) -> {
                                    if (throwable != null) {
                                        pending.completeExceptionally(throwable);
                                    } else {
                                        pending.complete(response);
                                    }
                                });
                    }

                    // Publish the channel for the general public
                    establishedConnections.put(serverAddress, established);
                    pendingConnections.remove(serverAddress);

                    // Check shut down for possible race with shut down. We
                    // don't want any lingering connections after shut down,
                    // which can happen if we don't check this here.
                    if (clientShutdownFuture.get() != null) {
                        if (establishedConnections.remove(serverAddress, established)) {
                            established.close();
                        }
                    }
                }
            }
        }

        /**
         * Close the connecting channel with a ClosedChannelException.
         */
        private CompletableFuture<Void> close() {
            return close(new ClosedChannelException());
        }

        /**
         * Close the connecting channel with an Exception (can be {@code null})
         * or forward to the established channel.
         *
         * @param cause the failure to report to queued requests; a
         *              {@link ClosedChannelException} is used if {@code null}.
         * @return the shutdown future of this connection (the one of the first close call).
         */
        private CompletableFuture<Void> close(Throwable cause) {
            CompletableFuture<Void> future = new CompletableFuture<>();
            if (connectionShutdownFuture.compareAndSet(null, future)) {
                // CompletableFuture#completeExceptionally rejects null, so substitute a
                // default cause instead of failing while holding the lock.
                final Throwable failure = cause != null ? cause : new ClosedChannelException();

                synchronized (connectLock) {
                    if (failureCause == null) {
                        failureCause = failure;
                    }

                    if (established != null) {
                        established.close().whenComplete((result, throwable) -> {
                            if (throwable != null) {
                                future.completeExceptionally(throwable);
                            } else {
                                future.complete(null);
                            }
                        });
                    } else {
                        PendingRequest pending;
                        while ((pending = queuedRequests.poll()) != null) {
                            pending.completeExceptionally(failure);
                        }
                        future.complete(null);
                    }
                }
            }
            return connectionShutdownFuture.get();
        }

        @Override
        public String toString() {
            synchronized (connectLock) {
                return "PendingConnection{" +
                        "serverAddress=" + serverAddress +
                        ", queuedRequests=" + queuedRequests.size() +
                        ", established=" + (established != null) +
                        ", closed=" + (connectionShutdownFuture.get() != null) +
                        '}';
            }
        }

        /**
         * A pending request queued while the channel is connecting.
         */
        private final class PendingRequest extends CompletableFuture<RESP> {

            private final REQ request;

            private PendingRequest(REQ request) {
                this.request = request;
            }
        }
    }

    /**
     * An established connection that wraps the actual channel instance and is
     * registered at the {@link ClientHandler} for callbacks.
     */
    private class EstablishedConnection implements ClientHandlerCallback<RESP> {

        /** Address of the server we are connected to. */
        private final InetSocketAddress serverAddress;

        /** The actual TCP channel. */
        private final Channel channel;

        /** Pending requests keyed by request ID. */
        private final ConcurrentHashMap<Long, TimestampedCompletableFuture> pendingRequests = new ConcurrentHashMap<>();

        /** Current request number used to assign unique request IDs. */
        private final AtomicLong requestCount = new AtomicLong();

        /** Atomic shut down future. */
        private final AtomicReference<CompletableFuture<Void>> connectionShutdownFuture = new AtomicReference<>(null);

        /**
         * Creates an established connection with the given channel.
         *
         * @param serverAddress Address of the server connected to
         * @param serializer Serializer passed on to the {@link ClientHandler}
         * @param channel The actual TCP channel
         */
        EstablishedConnection(
                final InetSocketAddress serverAddress,
                final MessageSerializer<REQ, RESP> serializer,
                final Channel channel) {

            this.serverAddress = Preconditions.checkNotNull(serverAddress);
            this.channel = Preconditions.checkNotNull(channel);

            // Add the client handler with the callback
            channel.pipeline().addLast(
                    getClientName() + " Handler",
                    new ClientHandler<>(clientName, serializer, this)
            );

            stats.reportActiveConnection();
        }

        /**
         * Close the channel with a ClosedChannelException.
         */
        CompletableFuture<Void> close() {
            return close(new ClosedChannelException());
        }

        /**
         * Close the channel with a cause.
         *
         * <p>All requests still pending on this connection are failed with the
         * given cause once the channel close has finished.
         *
         * @param cause The cause to close the channel with.
         * @return Channel close future
         */
        private CompletableFuture<Void> close(final Throwable cause) {
            final CompletableFuture<Void> shutdownFuture = new CompletableFuture<>();

            if (connectionShutdownFuture.compareAndSet(null, shutdownFuture)) {
                channel.close().addListener(finished -> {
                    stats.reportInactiveConnection();
                    for (long requestId : pendingRequests.keySet()) {
                        TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
                        if (pending != null && pending.completeExceptionally(cause)) {
                            stats.reportFailedRequest();
                        }
                    }

                    // when finishing, if netty successfully closes the channel, then the provided exception is used
                    // as the reason for the closing. If there was something wrong at the netty side, then that exception
                    // is prioritized over the provided one.
                    if (finished.isSuccess()) {
                        shutdownFuture.completeExceptionally(cause);
                    } else {
                        LOG.warn("Something went wrong when trying to close connection due to : ", cause);
                        shutdownFuture.completeExceptionally(finished.cause());
                    }
                });
            }

            // in case we had a race condition, return the winner of the race.
            return connectionShutdownFuture.get();
        }

        /**
         * Returns a future holding the serialized request result.
         *
         * @param request the request to be sent.
         * @return Future holding the serialized result
         */
        CompletableFuture<RESP> sendRequest(REQ request) {
            final TimestampedCompletableFuture requestPromiseTs =
                    new TimestampedCompletableFuture(System.nanoTime());

            // Assigned outside the try block so that the failure path below can
            // unregister the promise again.
            final long requestId = requestCount.getAndIncrement();

            try {
                pendingRequests.put(requestId, requestPromiseTs);

                stats.reportRequest();

                ByteBuf buf = MessageSerializer.serializeRequest(channel.alloc(), requestId, request);

                channel.writeAndFlush(buf).addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        // The write failed: fail the promise unless somebody else already did.
                        TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
                        if (pending != null && pending.completeExceptionally(future.cause())) {
                            stats.reportFailedRequest();
                        }
                    }
                });

                // Check for possible race. We don't want any lingering
                // promises after a failure, which can happen if we don't check
                // this here. Note that close is treated as a failure as well.
                CompletableFuture<Void> clShutdownFuture = clientShutdownFuture.get();
                if (clShutdownFuture != null) {
                    TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
                    if (pending != null) {
                        clShutdownFuture.whenComplete((ignored, throwable) -> {
                            // The client shutdown future may complete normally, in which
                            // case there is no throwable to forward; the request is failed
                            // (and counted as failed) either way.
                            final Throwable failure =
                                    throwable != null ? throwable : new ClosedChannelException();
                            if (pending.completeExceptionally(failure)) {
                                stats.reportFailedRequest();
                            }
                        });
                    }
                }
            } catch (Throwable t) {
                // Do not leave the promise registered: nothing would ever complete or
                // remove it until the connection is closed.
                pendingRequests.remove(requestId);
                if (requestPromiseTs.completeExceptionally(t)) {
                    stats.reportFailedRequest();
                }
            }

            return requestPromiseTs;
        }

        @Override
        public void onRequestResult(long requestId, RESP response) {
            TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
            if (pending != null && !pending.isDone()) {
                long durationMillis = (System.nanoTime() - pending.getTimestamp()) / 1_000_000L;
                stats.reportSuccessfulRequest(durationMillis);
                pending.complete(response);
            }
        }

        @Override
        public void onRequestFailure(long requestId, Throwable cause) {
            TimestampedCompletableFuture pending = pendingRequests.remove(requestId);
            if (pending != null && !pending.isDone()) {
                stats.reportFailedRequest();
                pending.completeExceptionally(cause);
            }
        }

        @Override
        public void onFailure(Throwable cause) {
            // Unregister this connection once the close has finished, whatever its outcome.
            close(cause).handle((cancelled, ignored) -> establishedConnections.remove(serverAddress, this));
        }

        @Override
        public String toString() {
            return "EstablishedConnection{" +
                    "serverAddress=" + serverAddress +
                    ", channel=" + channel +
                    ", pendingRequests=" + pendingRequests.size() +
                    ", requestCount=" + requestCount +
                    '}';
        }

        /**
         * Pair of promise and a timestamp.
         */
        private class TimestampedCompletableFuture extends CompletableFuture<RESP> {

            private final long timestampInNanos;

            TimestampedCompletableFuture(long timestampInNanos) {
                this.timestampInNanos = timestampInNanos;
            }

            public long getTimestamp() {
                return timestampInNanos;
            }
        }
    }

    /**
     * Tells whether the event loop group of this client has terminated.
     *
     * @return {@code true} if there is no bootstrap or its event loop group is terminated.
     */
    @VisibleForTesting
    public boolean isEventGroupShutdown() {
        return bootstrap == null || bootstrap.group().isTerminated();
    }
}