/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.rdma;

import com.ibm.disni.rdma.verbs.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ConcurrentHashMap;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A single RDMA connection: owns the CM id, queue pair, completion queue and completion channel,
 * and provides queued send / RDMA read / RDMA write primitives on top of them.
 *
 * <p>Send work requests are throttled by {@code sendBudgetSemaphore}; requests that cannot be
 * posted immediately are parked in {@code sendWrQueue} and re-posted from {@link #exhaustCq()}
 * as completions free up budget. For {@link RdmaChannelType#RPC} channels with software flow
 * control enabled, sends additionally consume {@code remoteRecvCredits}, which are replenished
 * by credit reports received as RDMA-write-with-immediate completions.
 */
public class RdmaChannel {
  private static final Logger logger = LoggerFactory.getLogger(RdmaChannel.class);
  // Number of CQ events accumulated before they are acknowledged in one batch
  private static final int MAX_ACK_COUNT = 4;

  // Number of work completions requested per CQ poll
  private static final int POLL_CQ_LIST_SIZE = 16;
  // Number of zero-sized receive WRs posted per postRecv call
  private static final int ZERO_SIZED_RECV_WR_LIST_SIZE = 16;
  private static final AtomicInteger idGenerator = new AtomicInteger(0);
  private final int id = idGenerator.getAndIncrement();
  // Cache of reusable SVCPostSend objects for RDMA reads, keyed by number of WRs in the list
  private final ConcurrentHashMap<Integer, ConcurrentLinkedDeque<SVCPostSend>> svcPostSendCache =
    new ConcurrentHashMap<>();

  enum RdmaChannelType { RPC, RDMA_READ_REQUESTOR, RDMA_READ_RESPONDER }
  private final RdmaChannelType rdmaChannelType;

  private final RdmaCompletionListener receiveListener;
  private final RdmaBufferManager rdmaBufferManager;
  private IbvCompChannel compChannel = null;
  // Only assigned by connect(); stays null for channels created around an existing cmId
  private RdmaEventChannel eventChannel = null;
  private final int rdmaCmEventTimeout;
  private final int teardownListenTimeout;
  private final int resolvePathTimeout;
  private RdmaCmId cmId = null;
  private IbvCQ cq = null;
  private IbvQP qp = null;
  private final AtomicBoolean isStopped = new AtomicBoolean(false);

  // Send a credit report on every (recvDepth / RECV_CREDIT_REPORT_RATIO) receive credit reclaims
  private static final int RECV_CREDIT_REPORT_RATIO = 8;
  // Null unless this is an RPC channel with software flow control enabled
  private Semaphore remoteRecvCredits;
  private int localRecvCreditsPendingReport = 0;

  // Null for RDMA_READ_RESPONDER channels, which never post sends
  private Semaphore sendBudgetSemaphore;
  private final ConcurrentLinkedDeque<PendingSend> sendWrQueue = new ConcurrentLinkedDeque<>();

  /** A list of send WRs waiting for send budget (and, optionally, remote receive credits). */
  private class PendingSend {
    final LinkedList<IbvSendWR> ibvSendWRList;
    final int recvCreditsNeeded;

    PendingSend(LinkedList<IbvSendWR> ibvSendWRList, int recvCreditsNeeded) {
      this.ibvSendWRList = ibvSendWRList;
      this.recvCreditsNeeded = recvCreditsNeeded;
    }
  }

  /** A receive WR together with the registered buffer backing it. */
  private class PostRecvWr {
    final IbvRecvWR ibvRecvWR;
    final RdmaBuffer rdmaBuf;
    final ByteBuffer buf;

    PostRecvWr(IbvRecvWR ibvRecvWR, RdmaBuffer rdmaBuf) throws IOException {
      this.ibvRecvWR = ibvRecvWR;
      this.rdmaBuf = rdmaBuf;
      this.buf = rdmaBuf.getByteBuffer();
    }
  }
  private PostRecvWr[] postRecvWrArray = null;

  private int ackCounter = 0;

  private final int sendDepth;
  private final int recvDepth;
  private final int recvWrSize;
  private LinkedList<IbvRecvWR> zeroSizeRecvWrList;
  private boolean isWarnedOnSendOverSubscription = false;
  private final int cpuVector;

  private SVCReqNotify reqNotifyCall;
  private SVCPollCq svcPollCq;
  private IbvWC[] ibvWCs;

  private RdmaThread rdmaThread = null;

  enum RdmaChannelState { IDLE, CONNECTING, CONNECTED, ERROR }
  private final AtomicInteger rdmaChannelState =
    new AtomicInteger(RdmaChannelState.IDLE.ordinal());

  /** Moves the channel to {@code newRdmaChannelState}; ERROR is sticky and is never left. */
  private void setRdmaChannelState(RdmaChannelState newRdmaChannelState) {
    // Allow to change the channel state only if not in ERROR
    rdmaChannelState.updateAndGet(
      state -> state != RdmaChannelState.ERROR.ordinal() ? newRdmaChannelState.ordinal() : state);
  }

  /** Listener to notify, and number of send permits to give back, when a signaled WR completes. */
  private class CompletionInfo {
    final RdmaCompletionListener listener;
    final int sendPermitsToReclaim;

    CompletionInfo(RdmaCompletionListener listener, int sendPermitsToReclaim) {
      this.listener = listener;
      this.sendPermitsToReclaim = sendPermitsToReclaim;
    }
  }
  private final ConcurrentHashMap<Integer, CompletionInfo> completionInfoMap =
    new ConcurrentHashMap<>();
  // NOOP_RESERVED_INDEX is used for send operations that do not require a callback
  private static final int NOOP_RESERVED_INDEX = 0;
  private final AtomicInteger completionInfoIndex = new AtomicInteger(NOOP_RESERVED_INDEX);
  private final RdmaShuffleConf conf;

  /**
   * Creates a channel around an already existing {@code cmId}; such a channel is expected to be
   * completed with {@link #accept()} rather than {@link #connect(InetSocketAddress)}.
   */
  RdmaChannel(
      RdmaChannelType rdmaChannelType,
      RdmaShuffleConf conf,
      RdmaBufferManager rdmaBufferManager,
      RdmaCompletionListener receiveListener,
      RdmaCmId cmId,
      int cpuVector) {
    this(rdmaChannelType, conf, rdmaBufferManager, receiveListener, cpuVector);
    this.cmId = cmId;
  }

  /**
   * Creates an unconnected channel and sizes its send/receive queues according to the channel
   * type and {@code conf}.
   *
   * @throws IllegalArgumentException if {@code rdmaChannelType} is not a known type
   */
  RdmaChannel(
      RdmaChannelType rdmaChannelType,
      RdmaShuffleConf conf,
      RdmaBufferManager rdmaBufferManager,
      RdmaCompletionListener receiveListener,
      int cpuVector) {
    this.rdmaChannelType = rdmaChannelType;
    this.receiveListener = receiveListener;
    this.rdmaBufferManager = rdmaBufferManager;
    this.cpuVector = cpuVector;
    this.conf = conf;

    switch (rdmaChannelType) {
      case RPC:
        // Single bidirectional QP between executors and driver.
        if (conf.swFlowControl()) {
          this.remoteRecvCredits = new Semaphore(
            conf.recvQueueDepth() - RECV_CREDIT_REPORT_RATIO, false);
        }
        this.recvDepth = conf.recvQueueDepth();
        this.recvWrSize = conf.recvWrSize();
        this.sendDepth = conf.sendQueueDepth();
        this.sendBudgetSemaphore = new Semaphore(sendDepth - RECV_CREDIT_REPORT_RATIO, false);
        break;

      case RDMA_READ_REQUESTOR:
        // Requires sends only, no need for any receives
        this.recvDepth = 0;
        this.recvWrSize = 0;
        this.sendDepth = conf.sendQueueDepth();
        this.sendBudgetSemaphore = new Semaphore(sendDepth, false);
        break;

      case RDMA_READ_RESPONDER:
        // Doesn't require sends nor receives
        this.recvDepth = 0;
        this.recvWrSize = 0;
        this.sendDepth = 0;
        break;

      default:
        throw new IllegalArgumentException("Illegal RdmaChannelType");
    }

    this.rdmaCmEventTimeout = conf.rdmaCmEventTimeout();
    this.teardownListenTimeout = conf.teardownListenTimeout();
    this.resolvePathTimeout = conf.resolvePathTimeout();
  }

  /**
   * Registers {@code completionInfo} under a fresh non-zero index, to be used as the WR id.
   *
   * @return the index under which the info was stored
   * @throws RuntimeException if the chosen index is still occupied by an earlier entry
   */
  private int putCompletionInfo(CompletionInfo completionInfo) {
    int index;
    do {
      index = completionInfoIndex.incrementAndGet();
    } while (index == NOOP_RESERVED_INDEX);

    CompletionInfo retCompletionInfo = completionInfoMap.put(index, completionInfo);
    if (retCompletionInfo != null) {
      throw new RuntimeException("Overflow of CompletionInfos");
    }
    return index;
  }

  /** Removes and returns the completion info stored under {@code index}, or null if absent. */
  private CompletionInfo removeCompletionInfo(int index) {
    return completionInfoMap.remove(index);
  }

  /**
   * Creates the completion channel, CQ and QP on {@code cmId}, posts the initial receives and
   * starts the {@link RdmaThread} for this channel.
   *
   * @throws IOException if any of the verbs objects cannot be created
   */
  private void setupCommon() throws IOException {
    IbvContext ibvContext = cmId.getVerbs();
    if (ibvContext == null) {
      throw new IOException("Failed to retrieve IbvContext");
    }

    compChannel = ibvContext.createCompChannel();
    if (compChannel == null) {
      throw new IOException("createCompChannel() failed");
    }

    // ncqe must be at least 1, even for channels with no sends and no receives
    cq = ibvContext.createCQ(compChannel,
      (sendDepth + recvDepth) > 0 ? sendDepth + recvDepth : 1, cpuVector);
    if (cq == null) {
      throw new IOException("createCQ() failed");
    }

    reqNotifyCall = cq.reqNotification(false);
    reqNotifyCall.execute();

    ibvWCs = new IbvWC[POLL_CQ_LIST_SIZE];
    for (int i = 0; i < POLL_CQ_LIST_SIZE; i++) {
      ibvWCs[i] = new IbvWC();
    }
    svcPollCq = cq.poll(ibvWCs, POLL_CQ_LIST_SIZE);

    IbvQPInitAttr attr = new IbvQPInitAttr();
    attr.setQp_type(IbvQP.IBV_QPT_RC);
    attr.setSend_cq(cq);
    attr.setRecv_cq(cq);
    attr.cap().setMax_recv_sge(1);
    attr.cap().setMax_recv_wr(recvDepth);
    attr.cap().setMax_send_sge(1);
    attr.cap().setMax_send_wr(sendDepth);

    qp = cmId.createQP(rdmaBufferManager.getPd(), attr);
    if (qp == null) {
      throw new IOException("createQP() failed");
    }

    if (recvWrSize == 0) {
      initZeroSizeRecvs();
    } else {
      initRecvs();
    }

    rdmaThread = new RdmaThread(this, cpuVector);
    rdmaThread.start();
  }

  /**
   * Actively connects to {@code socketAddress}: resolves address and route, sets up the verbs
   * resources and waits for the connection to be established.
   *
   * @throws IOException if any step fails or an unexpected CM event is received
   */
  void connect(InetSocketAddress socketAddress) throws IOException {
    eventChannel = RdmaEventChannel.createEventChannel();
    if (eventChannel == null) {
      throw new IOException("createEventChannel() failed");
    }

    // Create an active connect cm id
    cmId = eventChannel.createId(RdmaCm.RDMA_PS_TCP);
    if (cmId == null) {
      throw new IOException("createId() failed");
    }

    // Resolve the addr
    setRdmaChannelState(RdmaChannelState.CONNECTING);
    int err = cmId.resolveAddr(null, socketAddress, resolvePathTimeout);
    if (err != 0) {
      throw new IOException("resolveAddr() failed: " + err);
    }

    processRdmaCmEvent(RdmaCmEvent.EventType.RDMA_CM_EVENT_ADDR_RESOLVED.ordinal(),
      rdmaCmEventTimeout);

    // Resolve the route
    err = cmId.resolveRoute(resolvePathTimeout);
    if (err != 0) {
      throw new IOException("resolveRoute() failed: " + err);
    }

    processRdmaCmEvent(RdmaCmEvent.EventType.RDMA_CM_EVENT_ROUTE_RESOLVED.ordinal(),
      rdmaCmEventTimeout);

    setupCommon();

    RdmaConnParam connParams = new RdmaConnParam();
    // TODO: current disni code does not support setting these
    // connParams.setInitiator_depth((byte) 16);
    // connParams.setResponder_resources((byte) 16);
    // retry infinite
    connParams.setRetry_count((byte) 7);
    connParams.setRnr_retry_count((byte) 7);

    err = cmId.connect(connParams);
    if (err != 0) {
      setRdmaChannelState(RdmaChannelState.ERROR);
      throw new IOException("connect() failed");
    }

    processRdmaCmEvent(RdmaCmEvent.EventType.RDMA_CM_EVENT_ESTABLISHED.ordinal(),
      rdmaCmEventTimeout);
    setRdmaChannelState(RdmaChannelState.CONNECTED);
  }

  /** Returns the local (source) address of the underlying CM id. */
  InetSocketAddress getSourceSocketAddress() throws IOException {
    return (InetSocketAddress)cmId.getSource();
  }

  /**
   * Sets up the verbs resources and accepts the incoming connection on the existing
   * {@code cmId}. The channel stays in CONNECTING until {@link #finalizeConnection()} is called.
   *
   * @throws IOException if setup or the accept call fails
   */
  void accept() throws IOException {
    RdmaConnParam connParams = new RdmaConnParam();

    setupCommon();

    // TODO: current disni code does not support setting these
    //connParams.setInitiator_depth((byte) 16);
    //connParams.setResponder_resources((byte) 16);
    // retry infinite
    connParams.setRetry_count((byte) 7);
    connParams.setRnr_retry_count((byte) 7);

    setRdmaChannelState(RdmaChannelState.CONNECTING);

    int err = cmId.accept(connParams);
    if (err != 0) {
      setRdmaChannelState(RdmaChannelState.ERROR);
      throw new IOException("accept() failed");
    }
  }

  /** Marks the channel CONNECTED and wakes up threads in {@link #waitForActiveConnection()}. */
  void finalizeConnection() {
    setRdmaChannelState(RdmaChannelState.CONNECTED);
    synchronized (rdmaChannelState) {
      rdmaChannelState.notifyAll();
    }
  }

  /**
   * Waits up to {@code timeout} for the next CM event, acknowledges it and verifies its type.
   *
   * @param expectedEvent ordinal of the expected {@code RdmaCmEvent.EventType}
   * @throws IOException if no event arrives or the event differs from {@code expectedEvent};
   *                     in both cases the channel is moved to ERROR
   */
  private void processRdmaCmEvent(int expectedEvent, int timeout) throws IOException {
    RdmaCmEvent event = eventChannel.getCmEvent(timeout);
    if (event == null) {
      setRdmaChannelState(RdmaChannelState.ERROR);
      throw new IOException("getCmEvent() failed");
    }

    int eventType = event.getEvent();
    event.ackEvent();

    if (eventType != expectedEvent) {
      setRdmaChannelState(RdmaChannelState.ERROR);
      throw new IOException("Received CM event: " + RdmaCmEvent.EventType.values()[eventType] +
        " but expected: " + RdmaCmEvent.EventType.values()[expectedEvent]);
    }
  }

  /**
   * Blocks for at most 100 ms or until {@link #finalizeConnection()} notifies. It does not check
   * the state itself; callers are expected to re-check {@link #isConnected()}.
   */
  @SuppressWarnings({"checkstyle:EmptyCatchBlock"})
  void waitForActiveConnection() {
    synchronized (rdmaChannelState) {
      try {
        rdmaChannelState.wait(100);
      } catch (InterruptedException ignored) {
        // Intentionally ignored: this is a bounded wait and the caller re-checks the state
      }
    }
  }

  /**
   * Posts {@code sendWRList} on the QP immediately, without any budget accounting.
   *
   * <p>For RDMA reads a cached {@code SVCPostSend} with the same number of WRs is reused and
   * patched in place, to avoid buffer allocations in disni; other opcodes allocate and free a
   * fresh object.
   *
   * @throws IOException if the channel is stopped or in ERROR, or if posting fails
   */
  private void rdmaPostWRList(LinkedList<IbvSendWR> sendWRList) throws IOException {
    if (isError() || isStopped.get()) {
      throw new IOException("QP is in error state, can't post new requests");
    }

    ConcurrentLinkedDeque<SVCPostSend> stack;
    SVCPostSend svcPostSendObject;

    int numWrElements = sendWRList.size();
    // Special case for 0 sgeElements when rdmaSendWithImm
    if (sendWRList.size() == 1 && sendWRList.getFirst().getNum_sge() == 0) {
      numWrElements = NOOP_RESERVED_INDEX;
    }

    stack = svcPostSendCache.computeIfAbsent(numWrElements,
      numElements -> new ConcurrentLinkedDeque<>());

    // To avoid buffer allocations in disni update cached SVCPostSendObject
    if (sendWRList.getFirst().getOpcode() == IbvSendWR.IbvWrOcode.IBV_WR_RDMA_READ.ordinal()
        && (svcPostSendObject = stack.pollFirst()) != null) {
      int i = 0;
      for (IbvSendWR sendWr: sendWRList) {
        SVCPostSend.SendWRMod sendWrMod = svcPostSendObject.getWrMod(i);

        sendWrMod.setWr_id(sendWr.getWr_id());
        sendWrMod.setSend_flags(sendWr.getSend_flags());
        // Setting up RDMA attributes
        sendWrMod.getRdmaMod().setRemote_addr(sendWr.getRdma().getRemote_addr());
        sendWrMod.getRdmaMod().setRkey(sendWr.getRdma().getRkey());
        sendWrMod.getRdmaMod().setReserved(sendWr.getRdma().getReserved());

        if (sendWr.getNum_sge() == 1) {
          IbvSge sge = sendWr.getSge(0);
          sendWrMod.getSgeMod(0).setLkey(sge.getLkey());
          sendWrMod.getSgeMod(0).setAddr(sge.getAddr());
          sendWrMod.getSgeMod(0).setLength(sge.getLength());
        }
        i++;
      }
    } else {
      svcPostSendObject = qp.postSend(sendWRList, null);
    }

    svcPostSendObject.execute();

    // Cache SVCPostSend objects only for RDMA Read requests
    if (sendWRList.getFirst().getOpcode() == IbvSendWR.IbvWrOcode.IBV_WR_RDMA_READ.ordinal()) {
      stack.add(svcPostSendObject);
    } else {
      svcPostSendObject.free();
    }
  }

  /**
   * Posts {@code pendingSend} if enough send budget (and receive credits, when required) is
   * available right now; otherwise parks it in {@code sendWrQueue} for {@link #exhaustCq()}.
   *
   * <p>If posting itself fails, the acquired permits are returned, the request is left in the
   * queue and the exception is rethrown.
   *
   * @throws IOException if the channel is stopped or in ERROR, or if posting fails
   */
  private void rdmaPostWRListInQueue(PendingSend pendingSend) throws IOException {
    if (isError() || isStopped.get()) {
      throw new IOException("QP is in error state, can't post new requests");
    }

    if (sendBudgetSemaphore.tryAcquire(pendingSend.ibvSendWRList.size())) {
      // Ordering is lost here since if there are credits avail they will be immediately utilized
      // without fairness. We don't care about fairness, since Spark doesn't expect the requests to
      // complete in a particular order
      if (pendingSend.recvCreditsNeeded > 0 && remoteRecvCredits != null &&
          !remoteRecvCredits.tryAcquire(pendingSend.recvCreditsNeeded)) {
        sendBudgetSemaphore.release(pendingSend.ibvSendWRList.size());
        sendWrQueue.add(pendingSend);
      } else {
        try {
          rdmaPostWRList(pendingSend.ibvSendWRList);
        } catch (Exception e) {
          if (remoteRecvCredits != null) {
            remoteRecvCredits.release(pendingSend.recvCreditsNeeded);
          }
          sendBudgetSemaphore.release(pendingSend.ibvSendWRList.size());
          sendWrQueue.add(pendingSend);
          throw e;
        }
      }
    } else {
      if (!isWarnedOnSendOverSubscription) {
        isWarnedOnSendOverSubscription = true;
        logger.warn(this + " oversubscription detected. RDMA" +
          " send queue depth is too small. To improve performance, please set" +
          " spark.shuffle.rdma.sendQueueDepth to a higher value (current depth: " + sendDepth);
      }
      sendWrQueue.add(pendingSend);

      // Try again, in case it is the only WR in the queue and there are no pending sends
      if (sendBudgetSemaphore.tryAcquire(pendingSend.ibvSendWRList.size())) {
        if (sendWrQueue.remove(pendingSend)) {
          if (pendingSend.recvCreditsNeeded > 0 && remoteRecvCredits != null &&
              !remoteRecvCredits.tryAcquire(pendingSend.recvCreditsNeeded)) {
            sendBudgetSemaphore.release(pendingSend.ibvSendWRList.size());
            sendWrQueue.add(pendingSend);
          } else {
            try {
              rdmaPostWRList(pendingSend.ibvSendWRList);
            } catch (Exception e) {
              if (remoteRecvCredits != null) {
                remoteRecvCredits.release(pendingSend.recvCreditsNeeded);
              }
              sendBudgetSemaphore.release(pendingSend.ibvSendWRList.size());
              sendWrQueue.add(pendingSend);
              throw e;
            }
          }
        } else {
          // Someone else already took the request off the queue; give the permits back
          sendBudgetSemaphore.release(pendingSend.ibvSendWRList.size());
        }
      }
    }
  }

  /**
   * Queues one RDMA read per remote address; the i-th read lands in the local buffer right
   * after the previous one. Only the last WR is signaled and carries the completion id.
   *
   * @param listener        notified when the last read completes (success or failure)
   * @param localAddress    start address of the local destination buffer
   * @param lKey            local key of the destination buffer
   * @param sizes           length of each read
   * @param remoteAddresses remote source address of each read
   * @param rKeys           remote key of each read
   * @throws IOException if the request could not be posted; the completion info is removed
   */
  void rdmaReadInQueue(RdmaCompletionListener listener, long localAddress, int lKey,
      int[] sizes, long[] remoteAddresses, int[] rKeys) throws IOException {
    long offset = 0;
    LinkedList<IbvSendWR> readWRList = new LinkedList<>();
    for (int i = 0; i < remoteAddresses.length; i++) {
      IbvSge readSge = new IbvSge();
      readSge.setAddr(localAddress + offset);
      readSge.setLength(sizes[i]);
      readSge.setLkey(lKey);
      offset += sizes[i];

      LinkedList<IbvSge> readSgeList = new LinkedList<>();
      readSgeList.add(readSge);

      IbvSendWR readWr = new IbvSendWR();
      readWr.setOpcode(IbvSendWR.IbvWrOcode.IBV_WR_RDMA_READ.ordinal());
      readWr.setSg_list(readSgeList);
      readWr.getRdma().setRemote_addr(remoteAddresses[i]);
      readWr.getRdma().setRkey(rKeys[i]);

      readWRList.add(readWr);
    }

    readWRList.getLast().setSend_flags(IbvSendWR.IBV_SEND_SIGNALED);
    int completionInfoId = putCompletionInfo(
      new CompletionInfo(listener, remoteAddresses.length));
    readWRList.getLast().setWr_id(completionInfoId);

    try {
      rdmaPostWRListInQueue(new PendingSend(readWRList, 0));
    } catch (Exception e) {
      removeCompletionInfo(completionInfoId);
      throw e;
    }
  }

  /**
   * RDMA write buffer(localAddress, localLength, lKey) to remote buffer at remoteAddress.
   *
   * @param listener      notified when the write completes (success or failure)
   * @param localAddress  address of the local source buffer
   * @param localLength   number of bytes to write
   * @param lKey          local key of the source buffer
   * @param remoteAddress address of the remote destination buffer
   * @param rKey          remote key of the destination buffer
   * @throws IOException if the request could not be posted; the completion info is removed
   */
  public void rdmaWriteInQueue(RdmaCompletionListener listener, long localAddress,
      int localLength, int lKey, long remoteAddress, int rKey) throws IOException {
    LinkedList<IbvSendWR> writeWRList = new LinkedList<>();

    IbvSge writeSge = new IbvSge();
    writeSge.setAddr(localAddress);
    writeSge.setLength(localLength);
    writeSge.setLkey(lKey);

    LinkedList<IbvSge> writeSgeList = new LinkedList<>();
    writeSgeList.add(writeSge);

    IbvSendWR writeWr = new IbvSendWR();
    writeWr.setOpcode(IbvSendWR.IbvWrOcode.IBV_WR_RDMA_WRITE.ordinal());
    writeWr.setSg_list(writeSgeList);
    writeWr.getRdma().setRemote_addr(remoteAddress);
    writeWr.getRdma().setRkey(rKey);
    writeWr.setSend_flags(IbvSendWR.IBV_SEND_SIGNALED);

    writeWRList.add(writeWr);

    int completionInfoId = putCompletionInfo(new CompletionInfo(listener, 1));
    writeWRList.getLast().setWr_id(completionInfoId);

    try {
      rdmaPostWRListInQueue(new PendingSend(writeWRList, 0));
    } catch (Exception e) {
      removeCompletionInfo(completionInfoId);
      throw e;
    }
  }

  /**
   * Queues one SEND WR per local buffer. Only the last WR is signaled and carries the
   * completion id; each WR consumes one remote receive credit when flow control is active.
   *
   * @param listener       notified when the last send completes (success or failure)
   * @param localAddresses address of each local buffer
   * @param lKeys          local key of each buffer
   * @param sizes          length of each buffer
   * @throws IOException if the request could not be posted; the completion info is removed
   */
  public void rdmaSendInQueue(RdmaCompletionListener listener, long[] localAddresses,
      int[] lKeys, int[] sizes) throws IOException {
    LinkedList<IbvSendWR> sendWRList = new LinkedList<>();
    for (int i = 0; i < localAddresses.length; i++) {
      IbvSge sendSge = new IbvSge();
      sendSge.setAddr(localAddresses[i]);
      sendSge.setLength(sizes[i]);
      sendSge.setLkey(lKeys[i]);

      LinkedList<IbvSge> sendSgeList = new LinkedList<>();
      sendSgeList.add(sendSge);

      IbvSendWR sendWr = new IbvSendWR();
      sendWr.setOpcode(IbvSendWR.IbvWrOcode.IBV_WR_SEND.ordinal());
      sendWr.setSg_list(sendSgeList);

      sendWRList.add(sendWr);
    }

    sendWRList.getLast().setSend_flags(IbvSendWR.IBV_SEND_SIGNALED);
    int completionInfoId = putCompletionInfo(
      new CompletionInfo(listener, localAddresses.length));
    sendWRList.getLast().setWr_id(completionInfoId);

    try {
      rdmaPostWRListInQueue(new PendingSend(sendWRList, sendWRList.size()));
    } catch (Exception e) {
      removeCompletionInfo(completionInfoId);
      throw e;
    }
  }

  // Used only for sending a receive credit report
  private void rdmaSendWithImm(int immData) throws IOException {
    LinkedList<IbvSendWR> sendWRList = new LinkedList<>();
    LinkedList<IbvSge> sendSgeList = new LinkedList<>();

    IbvSendWR sendWr = new IbvSendWR();
    sendWr.setOpcode(IbvSendWR.IbvWrOcode.IBV_WR_RDMA_WRITE_WITH_IMM.ordinal());
    sendWr.setImm_data(immData);
    sendWr.setSg_list(sendSgeList);
    sendWr.setSend_flags(IbvSendWR.IBV_SEND_SIGNALED);
    sendWr.setWr_id(NOOP_RESERVED_INDEX); // doesn't require a callback
    sendWRList.add(sendWr);

    // Posted directly: bypasses the send budget and the pending queue
    rdmaPostWRList(sendWRList);
  }

  /** Builds the reusable list of zero-sized receive WRs and posts {@code recvDepth} of them. */
  private void initZeroSizeRecvs() throws IOException {
    if (recvDepth == 0) {
      return;
    }

    IbvRecvWR wr = new IbvRecvWR();
    wr.setWr_id(recvDepth);
    wr.setNum_sge(0);

    // The same WR instance is added repeatedly; the list is only a posting template
    zeroSizeRecvWrList = new LinkedList<>();
    for (int i = 0; i < ZERO_SIZED_RECV_WR_LIST_SIZE; i++) {
      zeroSizeRecvWrList.add(wr);
    }

    postZeroSizeRecvWrs(recvDepth);
  }

  /**
   * Posts {@code count} zero-sized receive WRs, in batches of at most
   * {@code ZERO_SIZED_RECV_WR_LIST_SIZE}. Does nothing if the channel is stopped or in ERROR.
   */
  private void postZeroSizeRecvWrs(int count) throws IOException {
    if (isError() || isStopped.get() || recvDepth == 0) {
      return;
    }

    int cPosted = 0;
    List<IbvRecvWR> actualRecvWrList = zeroSizeRecvWrList;
    while (cPosted < count) {
      int cCurrentPost = ZERO_SIZED_RECV_WR_LIST_SIZE;
      if (count - cPosted < ZERO_SIZED_RECV_WR_LIST_SIZE) {
        // Last, partial batch
        actualRecvWrList = zeroSizeRecvWrList.subList(0, count - cPosted);
        cCurrentPost = count - cPosted;
      }

      SVCPostRecv svcPostRecv = qp.postRecv(actualRecvWrList, null);
      svcPostRecv.execute();
      svcPostRecv.free();

      cPosted += cCurrentPost;
    }
  }

  /**
   * Re-posts {@code count} receive WRs starting at {@code startIndex}, wrapping around
   * {@code recvDepth}. Each backing buffer is reset to {@code recvWrSize} before posting.
   * Does nothing if the channel is stopped or in ERROR.
   */
  private void postRecvWrs(int startIndex, int count) throws IOException {
    if (isError() || isStopped.get() || recvDepth == 0) {
      return;
    }

    LinkedList<IbvRecvWR> recvWrList = new LinkedList<>();
    for (int i = startIndex; i < startIndex + count; i++) {
      postRecvWrArray[i % recvDepth].buf.clear();
      postRecvWrArray[i % recvDepth].buf.limit(recvWrSize);
      recvWrList.add(postRecvWrArray[i % recvDepth].ibvRecvWR);
    }

    SVCPostRecv svcPostRecv = qp.postRecv(recvWrList, null);
    svcPostRecv.execute();
    svcPostRecv.free();
  }

  /**
   * Allocates {@code recvDepth} receive buffers of {@code recvWrSize} bytes from the buffer
   * manager and posts one receive WR per buffer; the WR id is the index in
   * {@code postRecvWrArray}.
   */
  private void initRecvs() throws IOException {
    if (isError() || isStopped.get() || recvDepth == 0) {
      return;
    }

    postRecvWrArray = new PostRecvWr[recvDepth];
    LinkedList<IbvRecvWR> recvWrList = new LinkedList<>();
    for (int i = 0; i < recvDepth; i++) {
      RdmaBuffer rdmaBuffer = rdmaBufferManager.get(recvWrSize);

      IbvSge sge = new IbvSge();
      sge.setAddr(rdmaBuffer.getAddress());
      sge.setLength(rdmaBuffer.getLength());
      sge.setLkey(rdmaBuffer.getLkey());

      LinkedList<IbvSge> sgeList = new LinkedList<>();
      sgeList.add(sge);

      IbvRecvWR wr = new IbvRecvWR();
      wr.setWr_id(i);
      wr.setSg_list(sgeList);

      postRecvWrArray[i] = new PostRecvWr(wr, rdmaBuffer);
      recvWrList.add(wr);
    }

    SVCPostRecv svcPostRecv = qp.postRecv(recvWrList, null);
    svcPostRecv.execute();
    svcPostRecv.free();
  }

  /**
   * Polls the CQ until it is empty, dispatching each completion to its listener, then re-posts
   * the consumed receives, sends a receive credit report if due, and finally drains as many
   * pending sends as the reclaimed and available send budget allows.
   *
   * @throws IOException if the channel is in ERROR after polling
   */
  private void exhaustCq() throws IOException {
    // Send permits freed by completions in this run; not yet returned to sendBudgetSemaphore
    int reclaimedSendPermits = 0;
    int reclaimedRecvWrs = 0;
    // WR id of the first receive completion seen in this run
    // NOTE(review): re-posting from this index presumably relies on receives completing in
    // posting order - confirm
    int firstRecvWrIndex = -1;

    while (true) {
      int res = svcPollCq.execute().getPolls();
      if (res < 0) {
        logger.error("PollCQ failed executing with res: " + res);
        break;
      } else if (res > 0) {
        for (int i = 0; i < res; i++) {
          boolean wcSuccess = ibvWCs[i].getStatus() == IbvWC.IbvWcStatus.IBV_WC_SUCCESS.ordinal();
          if (!wcSuccess && !isError()) {
            setRdmaChannelState(RdmaChannelState.ERROR);
            logger.error("Completion with error: " +
              IbvWC.IbvWcStatus.values()[ibvWCs[i].getStatus()].name());
          }

          if (ibvWCs[i].getOpcode() == IbvWC.IbvWcOpcode.IBV_WC_SEND.getOpcode() ||
              ibvWCs[i].getOpcode() == IbvWC.IbvWcOpcode.IBV_WC_RDMA_WRITE.getOpcode() ||
              ibvWCs[i].getOpcode() == IbvWC.IbvWcOpcode.IBV_WC_RDMA_READ.getOpcode()) {
            int completionInfoId = (int)ibvWCs[i].getWr_id();
            if (completionInfoId != NOOP_RESERVED_INDEX) {
              CompletionInfo completionInfo = removeCompletionInfo(completionInfoId);
              if (completionInfo != null) {
                if (wcSuccess) {
                  completionInfo.listener.onSuccess(null);
                } else {
                  completionInfo.listener.onFailure(
                    new IOException("RDMA Send/Write/Read WR completed with error: " +
                      IbvWC.IbvWcStatus.values()[ibvWCs[i].getStatus()].name()));
                }

                reclaimedSendPermits += completionInfo.sendPermitsToReclaim;
              } else if (wcSuccess) {
                // Ignore the case of error, as the listener will be invoked by the last WC
                logger.warn("Couldn't find CompletionInfo with index: " + completionInfoId);
              }
            }
          } else if (ibvWCs[i].getOpcode() == IbvWC.IbvWcOpcode.IBV_WC_RECV.getOpcode()) {
            int recvWrId = (int)ibvWCs[i].getWr_id();
            if (firstRecvWrIndex == -1) {
              firstRecvWrIndex = recvWrId;
            }

            if (wcSuccess) {
              if (recvWrSize > 0) {
                receiveListener.onSuccess(postRecvWrArray[recvWrId].buf);
              } else {
                receiveListener.onSuccess(null);
              }
            } else {
              receiveListener.onFailure(
                new IOException(this + "RDMA Receive WR completed with error: " +
                  IbvWC.IbvWcStatus.values()[ibvWCs[i].getStatus()]));
            }

            reclaimedRecvWrs += 1;
          } else if (ibvWCs[i].getOpcode() ==
              IbvWC.IbvWcOpcode.IBV_WC_RECV_RDMA_WITH_IMM.getOpcode()) {
            // Receive credit report - update new credits
            if (remoteRecvCredits != null) {
              remoteRecvCredits.release(ibvWCs[i].getImm_data());
            }
            int recvWrId = (int)ibvWCs[i].getWr_id();
            if (firstRecvWrIndex == -1) {
              firstRecvWrIndex = recvWrId;
            }
            reclaimedRecvWrs += 1;
          } else {
            logger.error(this + "Unexpected opcode in PollCQ: " + ibvWCs[i].getOpcode());
          }
        }
      } else {
        break;
      }
    }

    if (isError()) {
      throw new IOException(this + "QP entered ERROR state");
    }

    if (reclaimedRecvWrs > 0) {
      if (recvWrSize > 0) {
        postRecvWrs(firstRecvWrIndex, reclaimedRecvWrs);
      } else {
        postZeroSizeRecvWrs(reclaimedRecvWrs);
      }
    }

    if (conf.swFlowControl() && rdmaChannelType == RdmaChannelType.RPC) {
      // Software-level flow control is enabled
      localRecvCreditsPendingReport += reclaimedRecvWrs;
      if (localRecvCreditsPendingReport > (recvDepth / RECV_CREDIT_REPORT_RATIO)) {
        // Send a credit report once (recvDepth / RECV_CREDIT_REPORT_RATIO) were accumulated
        try {
          rdmaSendWithImm(localRecvCreditsPendingReport);
        } catch (IOException ioe) {
          logger.warn(this + " Failed to send a receive credit report with exception: " + ioe +
            " failing silently.");
        }
        localRecvCreditsPendingReport = 0;
      }
    }

    // Drain pending sends queue
    while (sendBudgetSemaphore != null && !isStopped.get() && !isError()) {
      PendingSend pendingSend = sendWrQueue.poll();
      if (pendingSend != null) {
        // If there are not enough available permits from
        // this run AND from the semaphore, then it means that there are
        // more completions coming and they will exhaust the queue later
        if (pendingSend.ibvSendWRList.size() > reclaimedSendPermits) {
          if (!sendBudgetSemaphore.tryAcquire(
              pendingSend.ibvSendWRList.size() - reclaimedSendPermits)) {
            sendWrQueue.push(pendingSend);
            sendBudgetSemaphore.release(reclaimedSendPermits);
            break;
          } else {
            if (pendingSend.recvCreditsNeeded > 0 && remoteRecvCredits != null &&
                !remoteRecvCredits.tryAcquire(pendingSend.recvCreditsNeeded)) {
              sendWrQueue.push(pendingSend);
              // Permits held here: (size - reclaimedSendPermits) just acquired plus the
              // reclaimedSendPermits carried from completions, i.e. exactly size. Releasing
              // more than that would inflate the budget beyond the send queue depth.
              sendBudgetSemaphore.release(pendingSend.ibvSendWRList.size());
              break;
            } else {
              reclaimedSendPermits = 0;
            }
          }
        } else {
          if (pendingSend.recvCreditsNeeded > 0 && remoteRecvCredits != null &&
              !remoteRecvCredits.tryAcquire(pendingSend.recvCreditsNeeded)) {
            sendWrQueue.push(pendingSend);
            sendBudgetSemaphore.release(reclaimedSendPermits);
            break;
          } else {
            reclaimedSendPermits -= pendingSend.ibvSendWRList.size();
          }
        }

        try {
          rdmaPostWRList(pendingSend.ibvSendWRList);
        } catch (IOException e) {
          setRdmaChannelState(RdmaChannelState.ERROR);
          // reclaim the credit and put sendWRList back to the queue
          // however, the channel/QP is already broken and more actions
          // needed to be taken to recover
          reclaimedSendPermits += pendingSend.ibvSendWRList.size();
          if (remoteRecvCredits != null) {
            remoteRecvCredits.release(pendingSend.recvCreditsNeeded);
          }
          sendWrQueue.push(pendingSend);
          sendBudgetSemaphore.release(reclaimedSendPermits);
          break;
        }
      } else {
        // Queue is empty: return whatever is left of the reclaimed permits
        sendBudgetSemaphore.release(reclaimedSendPermits);
        break;
      }
    }
  }

  /**
   * Waits up to 50 ms for a CQ event and, if one arrived, re-arms the notification and
   * processes all available completions. CQ events are acknowledged in batches of
   * {@code MAX_ACK_COUNT}.
   *
   * @return true if a CQ event was received and processed, false otherwise
   * @throws IOException if the channel is in ERROR after processing completions
   */
  boolean processCompletions() throws IOException {
    // Disni's API uses a CQ here, which is wrong
    boolean success = compChannel.getCqEvent(cq, 50);
    if (success) {
      ackCounter++;
      if (ackCounter == MAX_ACK_COUNT) {
        cq.ackEvents(ackCounter);
        ackCounter = 0;
      }

      if (!isStopped.get()) {
        reqNotifyCall.execute();
      }

      exhaustCq();

      return true;
    } else if (isStopped.get() && ackCounter > 0) {
      cq.ackEvents(ackCounter);
      ackCounter = 0;
    }

    return false;
  }

  /**
   * Stops the channel once: stops the RDMA thread, fails all outstanding completions,
   * disconnects, and releases the QP, receive buffers, CQ, CM id and channels. Subsequent calls
   * are no-ops.
   */
  void stop() throws InterruptedException, IOException {
    if (!isStopped.getAndSet(true)) {
      logger.info("Stopping RdmaChannel " + this);

      if (rdmaThread != null) {
        rdmaThread.stop();
      }

      // Fail pending completionInfos
      for (Integer completionInfoId: completionInfoMap.keySet()) {
        final CompletionInfo completionInfo = completionInfoMap.remove(completionInfoId);
        if (completionInfo != null) {
          completionInfo.listener.onFailure(
            new IOException("RDMA Send/Read WR revoked since QP was removed"));
        }
      }

      if (cmId != null) {
        int ret = cmId.disconnect();
        if (ret != 0) {
          logger.error("disconnect failed with errno: " + ret);
        } else if (eventChannel != null &&
            (rdmaChannelType.equals(RdmaChannelType.RPC) ||
              rdmaChannelType.equals(RdmaChannelType.RDMA_READ_REQUESTOR))) {
          // eventChannel is only created by connect(); without it there is nothing to wait
          // on, and dereferencing it would abort the rest of the teardown with an NPE
          try {
            processRdmaCmEvent(RdmaCmEvent.EventType.RDMA_CM_EVENT_DISCONNECTED.ordinal(),
              teardownListenTimeout);
          } catch (IOException e) {
            logger.warn("Failed to get RDMA_CM_EVENT_DISCONNECTED: " + e.getLocalizedMessage());
          }
        }

        if (qp != null) {
          ret = cmId.destroyQP();
          if (ret != 0) {
            logger.error("destroyQP failed with errno: " + ret);
          }
        }
      }

      // Return the receive buffers to the buffer manager
      if (recvWrSize > 0 && postRecvWrArray != null) {
        for (int i = 0; i < recvDepth; i++) {
          if (postRecvWrArray[i] != null) {
            rdmaBufferManager.put(postRecvWrArray[i].rdmaBuf);
          }
        }
      }

      if (reqNotifyCall != null) {
        reqNotifyCall.free();
      }

      if (svcPollCq != null) {
        svcPollCq.free();
      }

      if (cq != null) {
        // Acknowledge any CQ events still pending before destroying the CQ
        if (ackCounter > 0) {
          cq.ackEvents(ackCounter);
        }
        int ret = cq.destroyCQ();
        if (ret != 0) {
          logger.error("destroyCQ failed with errno: " + ret);
        }
      }

      if (cmId != null) {
        int ret = cmId.destroyId();
        if (ret != 0) {
          logger.error("destroyId failed with errno: " + ret);
        }
      }

      if (compChannel != null) {
        int ret = compChannel.destroyCompChannel();
        if (ret != 0) {
          logger.error("destroyCompChannel failed with errno: " + ret);
        }
      }

      if (eventChannel != null) {
        int ret = eventChannel.destroyEventChannel();
        if (ret != 0) {
          logger.error("destroyEventChannel failed with errno: " + ret);
        }
      }
    }
  }

  /** Returns true if the channel state is CONNECTED. */
  boolean isConnected() { return rdmaChannelState.get() == RdmaChannelState.CONNECTED.ordinal(); }

  /** Returns true if the channel state is ERROR. */
  boolean isError() { return rdmaChannelState.get() == RdmaChannelState.ERROR.ordinal(); }

  @Override
  public String toString() {
    return "RdmaChannel(" + id + ") ";
  }
}