/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.nemo.runtime.executor.bytetransfer; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.group.ChannelGroup; import org.apache.nemo.runtime.common.comm.ControlMessage.ByteTransferContextSetupMessage; import org.apache.nemo.runtime.common.comm.ControlMessage.ByteTransferDataDirection; import org.apache.nemo.runtime.executor.bytetransfer.ByteTransferContext.ContextId; import org.apache.nemo.runtime.executor.data.BlockManagerWorker; import org.apache.nemo.runtime.executor.data.PipeManagerWorker; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; /** * Manages multiple transport contexts for one channel. */ final class ContextManager extends SimpleChannelInboundHandler<ByteTransferContextSetupMessage> { private final PipeManagerWorker pipeManagerWorker; private final BlockManagerWorker blockManagerWorker; private final ByteTransfer byteTransfer; private final ChannelGroup channelGroup; private final String localExecutorId; private final Channel channel; private volatile String remoteExecutorId = null; private final ConcurrentMap<Integer, ByteInputContext> inputContextsInitiatedByLocal = new ConcurrentHashMap<>(); private final ConcurrentMap<Integer, ByteOutputContext> outputContextsInitiatedByLocal = new ConcurrentHashMap<>(); private final ConcurrentMap<Integer, ByteInputContext> inputContextsInitiatedByRemote = new ConcurrentHashMap<>(); private final ConcurrentMap<Integer, ByteOutputContext> outputContextsInitiatedByRemote = new ConcurrentHashMap<>(); private final AtomicInteger nextInputTransferIndex = new AtomicInteger(0); private final AtomicInteger nextOutputTransferIndex = new AtomicInteger(0); /** * Creates context manager for this channel. * * @param pipeManagerWorker provides handler for new contexts by remote executors * @param blockManagerWorker provides handler for new contexts by remote executors * @param byteTransfer provides channel caching * @param channelGroup to cleanup this channel when closing {@link ByteTransport} * @param localExecutorId local executor id * @param channel the {@link Channel} to manage */ ContextManager(final PipeManagerWorker pipeManagerWorker, final BlockManagerWorker blockManagerWorker, final ByteTransfer byteTransfer, final ChannelGroup channelGroup, final String localExecutorId, final Channel channel) { this.pipeManagerWorker = pipeManagerWorker; this.blockManagerWorker = blockManagerWorker; this.byteTransfer = byteTransfer; this.channelGroup = channelGroup; this.localExecutorId = localExecutorId; this.channel = channel; } /** * @return channel for this context manager. */ Channel getChannel() { return channel; } /** * Returns {@link ByteInputContext} to provide {@link io.netty.buffer.ByteBuf}s on. * * @param dataDirection the data direction * @param transferIndex transfer index * @return the {@link ByteInputContext} corresponding to the pair of {@code dataDirection} and {@code transferIndex} */ ByteInputContext getInputContext(final ByteTransferDataDirection dataDirection, final int transferIndex) { final ConcurrentMap<Integer, ByteInputContext> contexts = dataDirection == ByteTransferDataDirection.INITIATOR_SENDS_DATA ? inputContextsInitiatedByRemote : inputContextsInitiatedByLocal; return contexts.get(transferIndex); } /** * Responds to new transfer contexts by a remote executor. * * @param ctx netty {@link ChannelHandlerContext} * @param message context setup message from the remote executor * @throws Exception exceptions from handler */ @Override protected void channelRead0(final ChannelHandlerContext ctx, final ByteTransferContextSetupMessage message) throws Exception { setRemoteExecutorId(message.getInitiatorExecutorId()); byteTransfer.onNewContextByRemoteExecutor(message.getInitiatorExecutorId(), channel); final ByteTransferDataDirection dataDirection = message.getDataDirection(); final int transferIndex = message.getTransferIndex(); final boolean isPipe = message.getIsPipe(); final ContextId contextId = new ContextId(remoteExecutorId, localExecutorId, dataDirection, transferIndex, isPipe); final byte[] contextDescriptor = message.getContextDescriptor().toByteArray(); if (dataDirection == ByteTransferDataDirection.INITIATOR_SENDS_DATA) { final ByteInputContext context = inputContextsInitiatedByRemote.compute(transferIndex, (index, existing) -> { if (existing != null) { throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId)); } return new ByteInputContext(remoteExecutorId, contextId, contextDescriptor, this); }); if (isPipe) { pipeManagerWorker.onInputContext(context); } else { blockManagerWorker.onInputContext(context); } } else { final ByteOutputContext context = outputContextsInitiatedByRemote.compute(transferIndex, (idx, existing) -> { if (existing != null) { throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId)); } return new ByteOutputContext(remoteExecutorId, contextId, contextDescriptor, this); }); if (isPipe) { pipeManagerWorker.onOutputContext(context); } else { blockManagerWorker.onOutputContext(context); } } } /** * Removes the specified contexts from map. * * @param context the {@link ByteTransferContext} to remove. */ void onContextExpired(final ByteTransferContext context) { final ContextId contextId = context.getContextId(); final ConcurrentMap<Integer, ? extends ByteTransferContext> contexts = context instanceof ByteInputContext ? (contextId.getDataDirection() == ByteTransferDataDirection.INITIATOR_SENDS_DATA ? inputContextsInitiatedByRemote : inputContextsInitiatedByLocal) : (contextId.getDataDirection() == ByteTransferDataDirection.INITIATOR_SENDS_DATA ? outputContextsInitiatedByLocal : outputContextsInitiatedByRemote); contexts.remove(contextId.getTransferIndex(), context); } /** * Initiates a context and stores to the specified map. * * @param contexts map for storing context * @param transferIndexCounter counter for generating transfer index * @param dataDirection data direction to include in the context id * @param contextGenerator a function that returns context from context id * @param executorId id of the remote executor * @param <T> {@link ByteInputContext} or {@link ByteOutputContext} * @param isPipe is a pipe context * @return generated context */ <T extends ByteTransferContext> T newContext(final ConcurrentMap<Integer, T> contexts, final AtomicInteger transferIndexCounter, final ByteTransferDataDirection dataDirection, final Function<ContextId, T> contextGenerator, final String executorId, final boolean isPipe) { setRemoteExecutorId(executorId); final int transferIndex = transferIndexCounter.getAndIncrement(); final ContextId contextId = new ContextId(localExecutorId, executorId, dataDirection, transferIndex, isPipe); final T context = contexts.compute(transferIndex, (index, existingContext) -> { if (existingContext != null) { throw new RuntimeException(String.format("Duplicate ContextId: %s", contextId)); } return contextGenerator.apply(contextId); }); channel.writeAndFlush(context).addListener(context.getChannelWriteListener()); return context; } /** * Create a new {@link ByteInputContext}. * * @param executorId target executor id * @param contextDescriptor the context descriptor * @param isPipe is pipe * @return new {@link ByteInputContext} */ ByteInputContext newInputContext(final String executorId, final byte[] contextDescriptor, final boolean isPipe) { return newContext(inputContextsInitiatedByLocal, nextInputTransferIndex, ByteTransferDataDirection.INITIATOR_RECEIVES_DATA, contextId -> new ByteInputContext(executorId, contextId, contextDescriptor, this), executorId, isPipe); } /** * Create a new {@link ByteOutputContext}. * * @param executorId target executor id * @param contextDescriptor the context descriptor * @param isPipe is pipe * @return new {@link ByteOutputContext} */ ByteOutputContext newOutputContext(final String executorId, final byte[] contextDescriptor, final boolean isPipe) { return newContext(outputContextsInitiatedByLocal, nextOutputTransferIndex, ByteTransferDataDirection.INITIATOR_SENDS_DATA, contextId -> new ByteOutputContext(executorId, contextId, contextDescriptor, this), executorId, isPipe); } /** * Set this contest manager as connected to the specified remote executor. * * @param executorId the remote executor id */ private void setRemoteExecutorId(final String executorId) { if (remoteExecutorId == null) { remoteExecutorId = executorId; } else if (!executorId.equals(remoteExecutorId)) { throw new RuntimeException(String.format("Wrong ContextManager: (%s != %s)", executorId, remoteExecutorId)); } } @Override public void channelActive(final ChannelHandlerContext ctx) { channelGroup.add(ctx.channel()); } @Override public void channelInactive(final ChannelHandlerContext ctx) { channelGroup.remove(ctx.channel()); final Throwable cause = new Exception("Channel closed"); throwChannelErrorOnContexts(inputContextsInitiatedByLocal, cause); throwChannelErrorOnContexts(outputContextsInitiatedByLocal, cause); throwChannelErrorOnContexts(inputContextsInitiatedByRemote, cause); throwChannelErrorOnContexts(outputContextsInitiatedByRemote, cause); } /** * Invoke {@link ByteTransferContext#onChannelError(Throwable)} on the specified contexts. * * @param contexts map storing the contexts * @param cause the error * @param <T> {@link ByteInputContext} or {@link ByteOutputContext} */ private <T extends ByteTransferContext> void throwChannelErrorOnContexts(final ConcurrentMap<Integer, T> contexts, final Throwable cause) { for (final ByteTransferContext context : contexts.values()) { context.onChannelError(cause); } } }