/** * * Copyright (c) 2017 ytk-mp4j https://github.com/yuantiku * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ package com.fenbi.mp4j.comm; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.Serializer; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import com.fenbi.mp4j.exception.Mp4jException; import com.fenbi.mp4j.meta.ArrayMetaData; import com.fenbi.mp4j.meta.MapMetaData; import com.fenbi.mp4j.meta.MetaData; import com.fenbi.mp4j.operand.*; import com.fenbi.mp4j.operator.*; import com.fenbi.mp4j.rpc.IServer; import com.fenbi.mp4j.rpc.Server; import com.fenbi.mp4j.utils.CommUtils; import com.fenbi.mp4j.utils.KryoUtils; import com.fenbi.mp4j.utils.ScatterAllocate; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.ArrayPrimitiveWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.ipc.RPC; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.lang.management.ManagementFactory; import java.lang.reflect.Array; import java.net.*; import java.util.*; import java.util.concurrent.*; /** * ProcessCommSlave is used as multi-processes communication. 
* @author xialong
 */
public class ProcessCommSlave {
    public static final Logger LOG = LoggerFactory.getLogger(ProcessCommSlave.class);
    /** Separator between host and port in the address string exchanged with the master. */
    public final static String HOST_PORT_DELIM = "###";
    /** Number of attempts made when connecting a data socket to a peer. */
    public final static int SOCK_CONN_RETRY_TIME = 50;
    /** Sleep in milliseconds between data-socket connection attempts. */
    public final static int SOCK_CONN_SLEEP_TIME = 30;

    private final int slaveNum;
    private final int rank;
    /** "[rank=N] " prefix prepended to per-rank messages sent to the master. */
    private final String rankMsgPrefix;
    private final String[] slaveHosts;
    private final int[] slavePorts;
    // Most recently opened outgoing / accepted incoming data sockets.
    private volatile Socket sendSockRef;
    private volatile Socket recvSockRef;
    // Listening socket on which peers connect to send data to this rank.
    private final ServerSocket recvDataSock;
    private volatile boolean closed = false;
    private volatile int curStep = 0;
    private final ScheduledExecutorService scheduledThreadPool;
    // Number of failed heartbeat calls.
    private volatile int heartbeatExceptionNum = 0;
    private final IServer server;
    // Operand of the collective in progress; used by the send/recv worker threads.
    private volatile Operand operand;
    // Work queues: metadata to send, receive requests, and receive results.
    private BlockingDeque<MetaData> sendQueue = new LinkedBlockingDeque<>();
    private BlockingDeque<MetaData> recvQueue = new LinkedBlockingDeque<>();
    private BlockingDeque<MetaData> recvResultQueue = new LinkedBlockingDeque<>();

    /** Worker that hands each queued MetaData to operand.send with the Output for its destination rank. */
    private final Thread sendDataTask = new Thread() {
        @Override
        public void run() {
            try {
                while (true) {
                    MetaData metaData = sendQueue.take();
                    Output output = getOutput(metaData.getDestRank());
                    operand.send(output, metaData);
                }
            } catch (Exception e) {
                try {
                    exception(e);
                } catch (Mp4jException e1) {
                    // Reporting the failure failed as well: log the reporting failure with the
                    // original failure attached so neither stack trace is lost.
                    e1.addSuppressed(e);
                    LOG.error("Mp4jException", e1);
                }
            }
        }
    };

    /** Worker that serves receive requests: operand.recv on the next input, result to recvResultQueue. */
    private final Thread recvDataTask = new Thread() {
        @Override
        public void run() {
            try {
                while (true) {
                    MetaData metaData = recvQueue.take();
                    MetaData recvMetaData = operand.recv(getInput(), metaData);
                    recvResultQueue.put(recvMetaData);
                }
            } catch (Exception e) {
                try {
                    exception(e);
                } catch (Mp4jException e1) {
                    // Same as above: keep both the reporting failure and the original failure.
                    e1.addSuppressed(e);
                    LOG.error("Mp4jException", e1);
                }
            }
        }
    };

    /** Runnable scheduled periodically by the constructor to send heartbeats. */
    private final Thread sendHeartBeatTask = new Thread() {
        @Override
        public void run() {
            heartbeat();
        }
    };

    /**
     * Process communication constructor; each process should create exactly one ProcessCommSlave instance.
* @param loginName if you use ssh to execute command, you must provide login name, e.g. ssh loginName@host "your command" * @param masterHost master host name * @param masterPort master host port * @throws Mp4jException */ public ProcessCommSlave(String loginName, String masterHost, int masterPort) throws Mp4jException { try { LOG.info("master host:" + masterHost + ", master port:" + masterPort); InetSocketAddress address = new InetSocketAddress(masterHost, masterPort); server = (IServer) RPC.getProxy(IServer.class, IServer.versionID, address, new Configuration()); String host = InetAddress.getLocalHost().getHostName(); recvDataSock = new ServerSocket(0); int port = recvDataSock.getLocalPort(); Text addresses = server.getAllSlavesInfo(new Text(host + HOST_PORT_DELIM + port)); if (address == null) { throw new Mp4jException("slaves connecting master failed, may be this slave restarted or master port is occupied, task failed!"); } String[] slavesAddresses = addresses.toString().split(Server.ADDRESS_DELIM); slaveNum = slavesAddresses.length - 1; rank = Integer.parseInt(slavesAddresses[slaveNum]); rankMsgPrefix = "[rank=" + rank + "] "; LOG.info("this slave recv data port:" + port); LOG.info("slave num:" + slaveNum); LOG.info("slave rank:" + rank); // get pid // get name representing the running Java virtual machine. 
String name = ManagementFactory.getRuntimeMXBean().getName(); String pid = name.split("@")[0]; LOG.info("Pid is:" + pid); if (slaveNum > 1) { server.killMe(rank, new Text("ssh " + loginName + "@" + host + " \"kill -9 " + pid + "\"")); } else { server.killMe(rank, new Text("kill -9 " + pid)); } slaveHosts = new String[slaveNum]; slavePorts = new int[slaveNum]; for (int i = 0; i < slaveNum; i++) { String[] addr = slavesAddresses[i].split(HOST_PORT_DELIM); slaveHosts[i] = addr[0]; slavePorts[i] = Integer.parseInt(addr[1]); } LOG.info("slaves addresses:"); for (int i = 0; i < slaveNum; i++) { LOG.info(slaveHosts[i] + ":" + slavePorts[i]); } this.scheduledThreadPool = Executors.newScheduledThreadPool(1); scheduledThreadPool.scheduleAtFixedRate(sendHeartBeatTask, 5, 15, TimeUnit.SECONDS); sendDataTask.start(); recvDataTask.start(); } catch (Exception e) { LOG.error("exception!", e); throw new Mp4jException(e.getCause()); } info("this slave init finished!"); } private void heartbeat() { try { if (!closed) { server.heartbeat(rank); } } catch (Exception e) { e.printStackTrace(); heartbeatExceptionNum ++; LOG.error("rank:" + rank + " send heartbeat exception, exception time:" + heartbeatExceptionNum, e); if (heartbeatExceptionNum > 4) { LOG.error("heart beat exception > 4 master may be shutdowned! this slave will be shutdowned..."); System.exit(4); } } } /** * close communication. * @param code close code. 
* @throws Mp4jException */ public void close(int code) throws Mp4jException { LOG.info("close code=" + code); if (closed) return; if (server != null) { try { server.close(rank, code); } catch (Exception e) { throw new Mp4jException("Exception in close!", e.getCause()); } LOG.info("reduce closed!"); } try { if (recvSockRef != null) { recvSockRef.close(); LOG.info("Recv operand socket closed!"); } if (recvDataSock != null) { recvDataSock.close(); LOG.info("Data server socket closed!"); } if (sendSockRef != null) { sendSockRef.close(); LOG.info("Send operand socket closed!"); } } catch (IOException e) { throw new Mp4jException("exception in slave close!", e.getCause()); } this.operand = null; closed = true; LOG.info("reduce closed!"); } /** * create a file in master * @param content content in a file * @param fileName file name to be write * @throws Exception */ public void writeFile(String content , String fileName) throws Exception { server.writeFile(new Text(content), new Text(fileName)); } /** * send information to master * @param info information * @param onlyRank0 if just rank 0 can send information successfully. * @throws Mp4jException */ public void info(String info, boolean onlyRank0) throws Mp4jException { try { if (!onlyRank0) { server.info(rank, new Text(rankMsgPrefix + info)); } else { if (rank == 0) { server.info(rank, new Text(info)); } } } catch (Exception e) { throw new Mp4jException("exception in slave report!", e.getCause()); } } /** * only rank 0 can send information successfully. * @param info information * @throws Mp4jException */ public void info(String info) throws Mp4jException { info(info, true); } /** * send debug information to master * @param debug debug information * @param onlyRank0 if just rank 0 can send debug information successfully. 
* @throws Mp4jException if the RPC call fails (the failure is kept as the cause)
     */
    public void debug(String debug, boolean onlyRank0) throws Mp4jException {
        try {
            if (!onlyRank0) {
                server.debug(rank, new Text(rankMsgPrefix + debug));
            } else {
                if (rank == 0) {
                    server.debug(rank, new Text(debug));
                }
            }
        } catch (Exception e) {
            throw new Mp4jException("exception in slave report!", e);
        }
    }

    /**
     * Sends debug information to the master from rank 0 only; other ranks do nothing.
     * @param debug debug information
     * @throws Mp4jException if the RPC call fails
     */
    public void debug(String debug) throws Mp4jException {
        debug(debug, true);
    }

    /**
     * Sends an error message, prefixed with this rank, to the master.
     * @param error error info
     * @throws Mp4jException if the RPC call fails (the failure is kept as the cause)
     */
    public void error(String error) throws Mp4jException {
        try {
            server.error(rank, new Text(rankMsgPrefix + error));
        } catch (Exception e) {
            throw new Mp4jException("exception in slave report!", e);
        }
    }

    /**
     * Reports an exception, including its stack trace, to the master as an error.
     * Then waits 5 seconds and closes this slave with close code 1.
     * @param e exception to report
     * @throws Mp4jException if reporting, waiting or closing fails (the failure is kept as the cause)
     */
    public void exception(Exception e) throws Mp4jException {
        StringWriter sw = new StringWriter();
        PrintWriter pw = new PrintWriter(sw);
        e.printStackTrace(pw);
        try {
            error("slave exception:" + sw.toString());
            Thread.sleep(5000);
            close(1);
        } catch (Exception e1) {
            if (e1 instanceof InterruptedException) {
                // Preserve the caller's interrupt request instead of swallowing it.
                Thread.currentThread().interrupt();
            }
            throw new Mp4jException("exception in slave report!", e1);
        }
    }

    /**
     * @return number of slave processes
     */
    public int getSlaveNum() {
        return slaveNum;
    }

    /**
     * @return the rank of this process
     */
    public int getRank() {
        return rank;
    }

    /**
     * Connects to the receive port of {@code targetRank} and records the socket in sendSockRef.
     * Retries up to SOCK_CONN_RETRY_TIME times, sleeping SOCK_CONN_SLEEP_TIME ms between attempts,
     * and rethrows the last failure.
     */
    private Socket getSendDataSocket(int targetRank) throws InterruptedException, IOException {
        Socket sendDataSock = null;
        LOG.debug("send data, target rank:" + targetRank);
        for (int i = 0; i < SOCK_CONN_RETRY_TIME; i++) {
            try {
                String targetAddress = slaveHosts[targetRank];
                int targetRecvDataPort = slavePorts[targetRank];
                sendDataSock = new Socket(targetAddress, targetRecvDataPort);
                sendDataSock.setTcpNoDelay(true);
                sendDataSock.setSoLinger(true, 160);
                break;
            } catch (UnknownHostException e) {
                Thread.sleep(SOCK_CONN_SLEEP_TIME);
                if (i == SOCK_CONN_RETRY_TIME - 1)
throw e;
                continue;
            } catch (IOException e) {
                Thread.sleep(SOCK_CONN_SLEEP_TIME);
                if (i == SOCK_CONN_RETRY_TIME - 1)
                    throw e;
                continue;
            }
        }
        sendSockRef = sendDataSock;
        return sendDataSock;
    }

    /**
     * Accepts the next incoming data connection on this rank's listening socket and records it in
     * recvSockRef.
     */
    private Socket getRecvDataSocket() throws IOException {
        Socket recvDataSock = this.recvDataSock.accept();
        LOG.debug("recv socket:" + recvDataSock);
        recvSockRef = recvDataSock;
        return recvDataSock;
    }

    /**
     * Synchronizes processes through the master's barrier call; presumably it blocks until every
     * rank has arrived. TODO confirm against the server implementation.
     * @throws Mp4jException wrapping any failure of the barrier call
     */
    public void barrier() throws Mp4jException {
        try {
            server.barrier();
        } catch (Exception e) {
            throw new Mp4jException(e);
        }
    }

    /**
     * Gathers every rank's metadata and data at {@code rootRank} through repeated pairwise merges.
     * Partners are obtained from the master via {@code server.exchange(rank)}; the pairing policy
     * lives on the master and is not visible here. With the returned partner thatRank, this rank
     * receives when {@code (rank < thatRank || rank == rootRank) && thatRank != rootRank}.
     * Otherwise it sends everything it has accumulated to thatRank and stops.
     * A receiving rank merges each incoming message and asks for a new partner until
     * {@code sum == slaveNum}, which must only happen at the root. Merging works per container:
     * ARRAY appends ranks/froms/tos; MAP copies entries into its single map, so later values win
     * for equal keys. Both exit paths call {@link #barrier()} before returning.
     *
     * @param rootRank rank that ends up holding the gathered data
     * @param thisMetaData this rank's own contribution (callers set sum = 1)
     * @return the fully merged metadata at the root; at other ranks, the metadata as it was sent
     * @throws Mp4jException wrapping any failure
     */
    private <T> MetaData dynamicBinaryTreeGather(int rootRank, MetaData thisMetaData) throws Mp4jException {
        try {
            LOG.info("thismetadata:" + thisMetaData);
            int thatRank = server.exchange(rank);
            if ((rank < thatRank || rank == rootRank) && thatRank != rootRank) {
                LOG.info("this rank:" + rank + " recv data! that rank:" + thatRank);
                while (true) {
                    // Ask the recv worker for one message; thisMetaData is passed to operand.recv.
                    recvQueue.put(thisMetaData);
                    MetaData thatMetaData = recvResultQueue.take();
                    LOG.info("recv thatmetadata:" + thatMetaData);
                    // merge
                    switch (operand.getContainer()) {
                        case ARRAY:
                            thisMetaData.stepIncr(1);
                            thisMetaData.setSum(thisMetaData.getSum() + thatMetaData.getSum());
                            thisMetaData.setSegNum(thisMetaData.getSegNum() + thatMetaData.getSegNum());
                            ArrayMetaData thatArrayMetaData = thatMetaData.convertToArrayMetaData();
                            ArrayMetaData thisArrayMetaData = thisMetaData.convertToArrayMetaData();
                            // append ranks, froms, tos
                            thisArrayMetaData.getRanks().addAll(thatArrayMetaData.getRanks());
                            thisArrayMetaData.getSegFroms().addAll(thatArrayMetaData.getSegFroms());
                            thisArrayMetaData.getSegTos().addAll(thatArrayMetaData.getSegTos());
                            break;
                        case MAP:
                            thisMetaData.stepIncr(1);
                            thisMetaData.setSum(thisMetaData.getSum() + thatMetaData.getSum());
                            thisMetaData.setSegNum(1);
                            MapMetaData thatMapMetaData = thatMetaData.convertToMapMetaData();
                            MapMetaData thisMapMetaData = thisMetaData.convertToMapMetaData();
                            thisMapMetaData.getRanks().addAll(thatMapMetaData.getRanks());
                            // merge map, recv process always has 1 map
                            List<Map<String, T>> thatMapDataList = thatMetaData.getMapDataList();
                            Map<String, T> thisMap = (Map<String, T>)thisMapMetaData.getMapDataList().get(0);
                            for (Map<String, T> thatMap : thatMapDataList) {
                                for (Map.Entry<String, T> entry : thatMap.entrySet()) {
                                    thisMap.put(entry.getKey(), entry.getValue());
                                }
                            }
                            // reset dataNums = 1
                            thisMapMetaData.setDataNums(Arrays.asList(thisMap.size()));
                            break;
                        default:
                            throw new Mp4jException("unsupported container:" + operand.getContainer());
                    }
                    LOG.info("merged thisMetaData:" + thisMetaData);
                    if (thisMetaData.getSum() != slaveNum) {
                        // Not all contributions merged yet: get a new partner; leave the loop to send if required.
                        thatRank = server.exchange(rank);
                        if (!((rank < thatRank || rank == rootRank) && thatRank != rootRank)) {
                            break;
                        }
                    } else {
                        // All slaveNum contributions are merged, which is only valid at the root.
                        if (rank != rootRank) {
                            throw new Mp4jException("dynamic binary gather must stop at rootRank:" + rootRank + ", thisRank:" + rank);
                        }
                        barrier();
                        LOG.info("root gather finished");
                        return thisMetaData;
                    }
                }
            }
            // Sender path: forward everything accumulated so far to thatRank.
            LOG.info("this rank:" + rank + " send data! that rank:" + thatRank);
            thisMetaData.setDestRank(thatRank);
            LOG.info("send thatmetadata:" + thisMetaData);
            sendQueue.put(thisMetaData);
            barrier();
            LOG.info("gather finished");
            return thisMetaData;
        } catch (Exception e) {
            throw new Mp4jException(e);
        }
    }

    /**
     * Takes elements from all processes and gathers them to the root process; the data container is an array.
     * Process i sends the interval [sendfroms[i], sendtos[i]) to the root, placed at the same positions,
     * s.t. sendfroms[i] ≤ sendtos[i], sendfroms[i] ≥ sendtos[i-1]
     * @param arrData data array; every process has the same length.
     * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
     * @param sendfroms sending start positions, included.
     * @param sendtos sending end positions, excluded
     * @param rootRank root rank
     * @return on the root process, the gathered elements; on other processes, an
     *         invalid array containing intermediate results.
* @throws Mp4jException if sendfroms/sendtos do not have slaveNum entries, or wrapping any gather failure
     */
    public <T> T gatherArray(T arrData, Operand operand, int[] sendfroms, int[] sendtos, int rootRank) throws Mp4jException {
        // Argument checks happen outside the try so their exceptions are not re-wrapped.
        if (sendfroms.length != slaveNum) {
            throw new Mp4jException("sendfroms array length must be equal to slaveNum");
        }
        if (sendtos.length != slaveNum) {
            throw new Mp4jException("sendtos array length must be equal to slaveNum");
        }
        if (slaveNum == 1) {
            // Single process: nothing to move.
            return arrData;
        }
        CommUtils.isfromsTosLegal(sendfroms, sendtos);

        try {
            this.operand = operand;
            this.operand.setCollective(Collective.GATHER);
            this.operand.setContainer(Container.ARRAY);

            // This rank contributes exactly one segment: its own [from, to) slice.
            ArrayMetaData<T> ownSegment = new ArrayMetaData<>();
            ownSegment.setSrcRank(rank)
                    .setDestRank(-1)
                    .setStep(0)
                    .setSum(1)
                    .setCollective(Collective.GATHER)
                    .insert(rank, sendfroms[rank], sendtos[rank]);
            ownSegment.setArrData(arrData);

            MetaData gathered = dynamicBinaryTreeGather(rootRank, ownSegment);
            return (T) gathered.getArrData();
        } catch (Exception e) {
            throw new Mp4jException(e);
        }
    }

    /**
     * Takes elements from all processes and gathers them to the root process; the data container is a map.
     * Entry order is not preserved. If different processes hold the same key, only one of their values
     * is kept at the root.
     * @param mapData key is {@code String}, value is any object
     * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
     * @param rootRank root rank
     * @return on the root process, the gathered elements; on other processes, an invalid map
     *         containing intermediate results, or null.
* @throws Mp4jException wrapping any failure of the gather
     */
    public <T> Map<String, T> gatherMap(Map<String, T> mapData, Operand operand, int rootRank) throws Mp4jException {
        if (slaveNum == 1) {
            return mapData;
        }
        try {
            this.operand = operand;
            this.operand.setCollective(Collective.GATHER);
            this.operand.setContainer(Container.MAP);
            // One segment for this rank holding its whole map (a sum = 1 contribution).
            MapMetaData<T> mapMetaData = new MapMetaData<>();
            mapMetaData.setSrcRank(rank)
                    .setDestRank(-1)
                    .setStep(0)
                    .setSum(1)
                    .setCollective(Collective.GATHER)
                    .insert(rank, mapData.size());
            List<Map<String, T>> listMapList = new ArrayList<>();
            listMapList.add(mapData);
            mapMetaData.setMapDataList(listMapList);
            return (Map<String, T>)dynamicBinaryTreeGather(rootRank, mapMetaData).getMapDataList().get(0);
        } catch (Exception e) {
            throw new Mp4jException(e);
        }
    }

    /**
     * Takes elements from all processes and gathers them to all processes.
     * The operation can be viewed as a combination of gather and broadcast. It is implemented as a
     * ring in which each rank sends to (rank + 1) % slaveNum.
     * @param arrData data array; every process has the same length.
     * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
     * @param froms process i contributes the interval [froms[i], tos[i])
     * @param tos see froms
     * @return allgathered array
     * @throws Mp4jException if froms/tos do not have slaveNum entries, or wrapping any failure
     */
    public <T> T allgatherArray(T arrData, Operand operand, int[] froms, int[] tos) throws Mp4jException {
        if (froms.length != slaveNum) {
            throw new Mp4jException("froms array length must be equal to slaveNum");
        }
        if (tos.length != slaveNum) {
            throw new Mp4jException("tos array length must be equal to slaveNum");
        }
        if (slaveNum == 1) {
            return arrData;
        }
        CommUtils.isfromsTosLegal(froms, tos);
        try {
            this.operand = operand;
            this.operand.setCollective(Collective.ALL_GATHER);
            this.operand.setContainer(Container.ARRAY);
            ArrayMetaData<T> arrayMetaData = new ArrayMetaData<>();
            arrayMetaData.setSrcRank(rank)
                    .setDestRank((rank + 1) % slaveNum)
                    .setStep(0)
                    .setSum(1)
                    .setCollective(Collective.ALL_GATHER)
                    .insert(rank, froms[rank], tos[rank]);
            arrayMetaData.setArrData(arrData);
            return (T)ringAllgather(arrayMetaData, Container.ARRAY).getArrData();
        } catch (Exception e) {
            throw new Mp4jException(e);
        }
    }

    /**
     * Takes elements from all processes and gathers them to all processes.
     * Similar to {@link #allgatherArray(Object, Operand, int[], int[])}, but the container is a map.
     * @param mapData map data
     * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
     * @return list of slaveNum maps where index r holds rank r's map (a one-element list when slaveNum == 1)
     * @throws Mp4jException wrapping any failure
     */
    public <T> List<Map<String, T>> allgatherMap(Map<String, T> mapData, Operand operand) throws Mp4jException {
        if (slaveNum == 1) {
            return Arrays.asList(mapData);
        }
        try {
            this.operand = operand;
            this.operand.setCollective(Collective.ALL_GATHER);
            this.operand.setContainer(Container.MAP);
            MapMetaData<T> mapMetaData = new MapMetaData<>();
            mapMetaData.setSrcRank(rank)
                    .setDestRank((rank + 1) % slaveNum)
                    .setStep(0)
                    .setSum(1)
                    .setCollective(Collective.ALL_GATHER)
                    .insert(rank, mapData.size());
            List<Map<String, T>> listMap = new ArrayList<>(slaveNum);
            listMap.add(mapData);
            mapMetaData.setMapDataList(listMap);
            MapMetaData<T> retMapMetaData = (MapMetaData<T>)ringAllgather(mapMetaData, Container.MAP).convertToMapMetaData();
            return retMapMetaData.getMapDataList();
        } catch (Exception e) {
            throw new Mp4jException(e);
        }
    }

    /**
     * Ring allgather. This rank first sends its own block to (rank + 1) % slaveNum, then receives
     * slaveNum - 1 blocks from the ring. Each received block is forwarded onward, except the last.
     * For MAP, received maps are stored in a list indexed by their originating rank (the first entry
     * of the received metadata's ranks), and that list is installed on the returned metadata.
     * Ends with {@link #barrier()}.
     * @param thisMetaData this rank's block (callers address it to (rank + 1) % slaveNum)
     * @param container ARRAY or MAP
     * @return the last received metadata (carrying the full map list for MAP)
     */
    private <T> MetaData<T> ringAllgather(MetaData thisMetaData, Container container) throws Mp4jException {
        try {
            List<Map<String, T>> retMapList = null;
            if (container == Container.MAP){
                // One placeholder per rank, replaced as blocks arrive; own map goes at index rank.
                retMapList = new ArrayList<>(slaveNum);
                for (int i = 0; i < slaveNum; i++) {
                    retMapList.add(Collections.emptyMap());
                }
                retMapList.set(rank, (Map<String, T>)thisMetaData.getMapDataList().get(0));
            }
            // send first block
            sendQueue.put(thisMetaData);
            // recv slaveNum - 1 times, send slaveNum - 2
            for (int step = 1; step < slaveNum; step++) {
                // NOTE(review): the object handed to the recv worker here may still be queued for, or being
                // processed by, the send worker. Presumably operand.recv only fills the incoming block — confirm.
                recvQueue.push(thisMetaData);
                MetaData<T> thatMetaData = recvResultQueue.take();
                if (container == Container.MAP) {
                    // The received block's first rank is the rank that originally produced it.
                    int realOriginRank = thatMetaData.convertToMapMetaData().getRanks().get(0);
                    retMapList.set(realOriginRank, thatMetaData.getMapDataList().get(0));
                }
                thisMetaData = thatMetaData;
                // Re-address the received block to this rank's successor before forwarding it.
                thatMetaData.setSrcRank(rank)
                        .setDestRank((rank + 1) % slaveNum)
                        .setCollective(Collective.ALL_GATHER);
                if (step < slaveNum - 1) {
                    sendQueue.put(thatMetaData);
                }
            }
            if (container == Container.MAP) {
                thisMetaData.setMapDataList(retMapList);
            }
            barrier();
        } catch (Exception e) {
            throw new Mp4jException(e);
        }
        return thisMetaData;
    }

    /**
     * Broadcasts the root process's array interval to all processes (including itself).
     * [from, to) is split into slaveNum contiguous chunks of (to - from) / slaveNum elements, with the
     * last chunk extended to {@code to}. The chunks are scattered from the root and then allgathered.
     * @param arrData data array; every process has the same length.
     * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
     * @param from start position to broadcast
     * @param to end position to broadcast
     * @param rootRank root rank
     * @return root array, interval is [from, to)
     * @throws Mp4jException wrapping any failure
     */
    public <T> T broadcastArray(T arrData, Operand operand, int from, int to, int rootRank) throws Mp4jException {
        if (slaveNum == 1) {
            return arrData;
        }
        CommUtils.isFromToLegal(from, to);
        try {
            int[] froms = new int[slaveNum];
            int[] tos = new int[slaveNum];
            int avg = (to - from) / slaveNum;
            int fromidx = from;
            for (int r = 0; r < slaveNum; r++) {
                froms[r] = fromidx;
                tos[r] = fromidx + avg;
                fromidx += avg;
            }
            // The last chunk absorbs the remainder of the integer division.
            tos[slaveNum - 1] = to;
            T scatterArr = scatterArray(arrData, operand, froms, tos, rootRank);
            return allgatherArray(scatterArr, operand, froms, tos);
        } catch (Exception e) {
            throw new Mp4jException(e);
        }
    }

    /**
     * Broadcasts a single value from the root process to all processes (including itself).
     * @param value value to be broadcast (only the root process's value is used)
     * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
     * @param rootRank root rank
     * @return the value of the root process.
* @throws Mp4jException */ public <T> T broadcast(T value, Operand operand, int rootRank) throws Mp4jException { if (operand instanceof DoubleOperand) { double []doubleArr = new double[1]; doubleArr[0] = (Double) value; doubleArr = broadcastArray(doubleArr, operand, 0, doubleArr.length, rootRank); return (T)Double.valueOf(doubleArr[0]); } else if (operand instanceof FloatOperand) { float []floatArr = new float[1]; floatArr[0] = (Float) value; floatArr = broadcastArray(floatArr, operand, 0, floatArr.length, rootRank); return (T)Float.valueOf(floatArr[0]); } else if (operand instanceof IntOperand) { int []intArr = new int[1]; intArr[0] = (Integer) value; intArr = broadcastArray(intArr, operand, 0, intArr.length, rootRank); return (T)Integer.valueOf(intArr[0]); } else if (operand instanceof LongOperand) { long []longArr = new long[1]; longArr[0] = (Long) value; longArr = broadcastArray(longArr, operand, 0, longArr.length, rootRank); return (T)Long.valueOf(longArr[0]); } else if (operand instanceof ObjectOperand) { T []objectArr = (T[]) Array.newInstance(value.getClass(), 1); objectArr[0] = value; objectArr = broadcastArray(objectArr, operand, 0, objectArr.length, rootRank); return objectArr[0]; } else if (operand instanceof StringOperand) { String []stringArr = new String[1]; stringArr[0] = (String) value; stringArr = broadcastArray(stringArr, operand, 0, stringArr.length, rootRank); return (T)stringArr[0]; } else if (operand instanceof ShortOperand) { short []shortArr = new short[1]; shortArr[0] = (Short) value; shortArr = broadcastArray(shortArr, operand, 0, shortArr.length, rootRank); return (T)Short.valueOf(shortArr[0]); } else if (operand instanceof ByteOperand) { byte []byteArr = new byte[1]; byteArr[0] = (Byte) value; byteArr = broadcastArray(byteArr, operand, 0, byteArr.length, rootRank); return (T)Byte.valueOf(byteArr[0]); } else { throw new Mp4jException("unknown operand:" + operand); } } /** * Broadcast map in root process to all other processes(included 
itself), * similar with {@link #broadcastArray(Object, Operand, int, int, int)}, but this container is map. * @param mapData map data * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands}) * @param rootRank root rank * @return the root map. * @throws Mp4jException */ public <T> Map<String, T> broadcastMap(Map<String, T> mapData, Operand operand, int rootRank) throws Mp4jException { if (slaveNum == 1) { return mapData; } try { List<Map<String, T>> listMapData = new ArrayList<>(); if (rank == rootRank) { listMapData = new ArrayList<>(slaveNum); for (int i = 0; i < slaveNum; i++) { listMapData.add(new HashMap<>((int)((mapData.size() / slaveNum) * 1.2))); } for (Map.Entry<String, T> entry : mapData.entrySet()) { String key = entry.getKey(); T val = entry.getValue(); int idx = key.hashCode() % slaveNum; if (idx < 0) { idx += slaveNum; } listMapData.get(idx).put(key, val); } } Map<String, T> scatteredMap = scatterMap(listMapData, operand, rootRank); List<Map<String, T>> broadcasedMapList = allgatherMap(scatteredMap, operand); Map<String, T> retMap = broadcasedMapList.get(0); for (int i = 1; i < broadcasedMapList.size(); i++) { Map<String, T> retMapTemp = broadcasedMapList.get(i); for (Map.Entry<String, T> entry : retMapTemp.entrySet()) { retMap.put(entry.getKey(), entry.getValue()); } } return retMap; } catch (Exception e) { throw new Mp4jException(e); } } /** * Send chunks of an array to different processes, the data container is array. * process i receive data interval [recvfroms[i], recvtos[i]) from root process, placed in the same positions, * s.t. recvfroms[i] ≤ recvtos[i], recvfroms[i] ≥ recvtos[i-1] * @param arrData data array, each process have the same length. 
* @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands}) * @param recvfroms receiving start positions, included * @param recvtos receiving end positions, excluded * @param rootRank root rank * @return rank i receives data from root process, placed in [recvfroms[i], recvtos[i]) * @throws Mp4jException */ public <T> T scatterArray(T arrData, Operand operand, int[] recvfroms, int[] recvtos, int rootRank) throws Mp4jException { try { if (recvfroms.length != slaveNum) { throw new Mp4jException("recvfroms array length must be equal to slaveNum"); } if (recvtos.length != slaveNum) { throw new Mp4jException("recvtos array length must be equal to slaveNum"); } if (slaveNum == 1) { return arrData; } CommUtils.isfromsTosLegal(recvfroms, recvtos); this.operand = operand; this.operand.setCollective(Collective.SCATTER); this.operand.setContainer(Container.ARRAY); ArrayMetaData<T> arrayMetaData = new ArrayMetaData<>(); arrayMetaData.setSrcRank(rank) .setDestRank(-1) .setStep(0) .setSum(0) .setCollective(Collective.SCATTER); for (int i = 0; i < slaveNum; i++) { arrayMetaData.insert(i, recvfroms[i], recvtos[i]); } arrayMetaData.setArrData(arrData); return (T)binaryTreeScatter(rootRank, arrayMetaData, Container.ARRAY).convertToArrayMetaData().getArrData(); } catch (Exception e) { throw new Mp4jException(e); } } /** * Send chunks of data to different processes, the data container is map. * rank i receive data mapDataList.get(i) from root process. * @param mapDataList list of maps to be scattered. * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands}) * @param rootRank root rank * @return rank i receives data mapDataList.get(i) from root process. 
* @throws Mp4jException */ public <T> Map<String, T> scatterMap(List<Map<String, T>> mapDataList, Operand operand, int rootRank) throws Mp4jException { try { if (rank == rootRank && mapDataList.size() != slaveNum) { throw new Mp4jException("mapDataList size must be equal to slaveNum"); } if (slaveNum == 1) { return mapDataList.get(0); } LOG.info("entry scater map..."); this.operand = operand; this.operand.setCollective(Collective.SCATTER); this.operand.setContainer(Container.MAP); MapMetaData<T> MapMetaData = new MapMetaData<>(); MapMetaData.setSrcRank(rank) .setDestRank(-1) .setStep(0) .setSum(0) .setCollective(Collective.SCATTER); if (rank == rootRank) { for (int i = 0; i < slaveNum; i++) { MapMetaData.insert(i, mapDataList.get(i).size()); } MapMetaData.setMapDataList(mapDataList); } LOG.info("will call binary tree scatter..."); MapMetaData<T> metaData = (MapMetaData<T>)binaryTreeScatter(rootRank, MapMetaData, Container.MAP).convertToMapMetaData(); Map<String, T> retMap = null; List<Map<String, T>> listMapData = metaData.getMapDataList(); for (int i = 0; i < metaData.getSegNum(); i++) { int curRank = metaData.getRank(i); if (curRank == rank) { retMap = listMapData.get(i); } } listMapData.clear(); if (retMap == null) { throw new Mp4jException("scatter error retmap must not be null!"); } return retMap; } catch (Exception e) { throw new Mp4jException(e); } } protected <T> List<Map<String, T>> scatterMapSpecial(List<List<Map<String, T>>> mapDataListList, Operand operand, int rootRank) throws Mp4jException { try { if (rank == rootRank && mapDataListList.size() != slaveNum) { throw new Mp4jException("mapDataList size must be equal to slaveNum"); } if (slaveNum == 1) { return mapDataListList.get(0); } LOG.info("entry special scatter map..."); this.operand = operand; this.operand.setCollective(Collective.SCATTER); this.operand.setContainer(Container.MAP); MapMetaData<T> mapMetaData = new MapMetaData<>(); mapMetaData.setSrcRank(rank) .setDestRank(-1) .setStep(0) 
.setSum(0) .setCollective(Collective.SCATTER);
        if (rank == rootRank) {
            List<Map<String, T>> mapDataListForScatter = new ArrayList<>();
            for (int i = 0; i < slaveNum; i++) {
                List<Map<String, T>> mapDataList = mapDataListList.get(i);
                for (int j = 0; j < mapDataList.size(); j++) {
                    mapDataListForScatter.add(mapDataList.get(j));
                    mapMetaData.insert(i, mapDataList.get(j).size());
                }
            }
            LOG.info("special root map data list size:" + mapDataListForScatter.size() + ", metadata:" + mapMetaData);
            mapMetaData.setMapDataList(mapDataListForScatter);
        }
        LOG.info("will call binary tree scatter...");
        MapMetaData<T> metaData = (MapMetaData<T>)binaryTreeScatter(rootRank, mapMetaData, Container.MAP).convertToMapMetaData();
        LOG.info("binary tree scatter finished!");
        List<Map<String, T>> retMapList = new ArrayList<>();
        List<Map<String, T>> listMapData = metaData.getMapDataList();
        // keep only the segments addressed to this rank
        for (int i = 0; i < metaData.getSegNum(); i++) {
            int curRank = metaData.getRank(i);
            if (curRank == rank) {
                retMapList.add(listMapData.get(i));
            }
        }
        listMapData.clear();
        return retMapList;
    } catch (Exception e) {
        throw new Mp4jException(e);
    }
}

/**
 * Builds the message for one scatter send task.
 *
 * {@code sendInfo} is read positionally as [srcRank, destRank, rankFrom, rankTo]. Only the
 * segments of {@code metaData} whose rank lies in the inclusive range [rankFrom, rankTo]
 * are copied into the returned metadata, which is tagged as a SCATTER collective.
 *
 * @param metaData  metadata currently held by this process
 * @param sendInfo  send task: [srcRank, destRank, rankFrom, rankTo]
 * @param container ARRAY or MAP; anything else is rejected
 * @return new metadata carrying only the selected segments
 * @throws Mp4jException if {@code container} is neither ARRAY nor MAP
 */
private <T> MetaData<T> scatterSend(MetaData metaData, List<Integer> sendInfo, Container container) throws Mp4jException {
    MetaData newMetaData = MetaData.newMetaData(container);
    int newSrcRank = sendInfo.get(0);
    int newDestRank = sendInfo.get(1);
    int newRankFrom = sendInfo.get(2);
    int newRankTo = sendInfo.get(3);
    newMetaData.setSrcRank(newSrcRank);
    newMetaData.setDestRank(newDestRank);
    newMetaData.setCollective(Collective.SCATTER);
    int segNum = metaData.getSegNum();
    if (container == Container.ARRAY) {
        List<Integer> ranks = metaData.convertToArrayMetaData().getRanks();
        List<Integer> froms = metaData.convertToArrayMetaData().getSegFroms();
        List<Integer> tos = metaData.convertToArrayMetaData().getSegTos();
        for (int i = 0; i < segNum; i++) {
            // NOTE: this local 'rank' is the segment's owner and shadows the instance rank
            int rank = ranks.get(i);
            int from = froms.get(i);
            int to = tos.get(i);
            if (rank >= newRankFrom && rank <= newRankTo) {
                newMetaData.insert(rank, from, to);
            }
        }
        // the backing array is shared with the input, not copied; only the segment list is filtered
        newMetaData.setArrData(metaData.getArrData());
    } else if (container == Container.MAP) {
        List<Integer> dataNums = metaData.convertToMapMetaData().getDataNums();
        List<Integer> ranks = metaData.convertToMapMetaData().getRanks();
        List<Map<String, T>> mapList = metaData.convertToMapMetaData().getMapDataList();
        List<Map<String, T>> newMapList = new ArrayList<>();
        for (int i = 0; i < segNum; i++) {
            int rank = ranks.get(i);
            int dataNum = dataNums.get(i);
            if (rank >= newRankFrom && rank <= newRankTo) {
                newMetaData.insert(rank, dataNum);
                newMapList.add(mapList.get(i));
            }
        }
        newMetaData.setMapDataList(newMapList);
    } else {
        throw new Mp4jException("unsupport MetaData type in scatter send");
    }
    return newMetaData;
}

/**
 * Scatters {@code thisMetaData} following the send plan computed by {@link ScatterAllocate}.
 *
 * {@code ScatterAllocate.allocate} gives each rank a list of send tasks (each consumed by
 * {@link #scatterSend}), and {@code ScatterAllocate.recvNum} gives how many times it receives.
 * A rank may receive at most once. The root enqueues its initial sends right away. A rank that
 * receives waits for the incoming metadata and then forwards its remaining send tasks built
 * from that received data. The call ends with {@code barrier()}.
 *
 * @param rootRank     rank that owns the data to scatter
 * @param thisMetaData metadata held by this process before the scatter
 * @param container    ARRAY or MAP
 * @return the received metadata if this rank received, otherwise {@code thisMetaData}
 * @throws Mp4jException wrapping any failure, including a plan in which this rank receives
 *                       two or more times
 */
private <T> MetaData<T> binaryTreeScatter(int rootRank, MetaData thisMetaData, Container container) throws Mp4jException {
    try {
        LOG.info("entry binary tree scatter ...");
        Map<Integer, List<List<Integer>>> allocateMap = ScatterAllocate.allocate(slaveNum, rootRank);
        Map<Integer, Integer> recvNumMap = ScatterAllocate.recvNum(allocateMap);
        // get this rank's send task list
        List<List<Integer>> thisSendTaskList = allocateMap.getOrDefault(rank, Collections.emptyList());
        // get recv number
        Integer recvNum = recvNumMap.getOrDefault(rank, 0);
        if (recvNum >= 2) {
            throw new Mp4jException("scatter error, recv num must <= 1");
        }
        LOG.info("this send task list:" + thisSendTaskList);
        LOG.info("this recv num:" + recvNum);
        // every process only recv once data except the root process
        // thisSendCursor marks how many of this rank's send tasks have already been enqueued
        int thisSendCursor = 0;
        if (rank == rootRank) {
            if (recvNum > 0) {
                List<Integer> half0 = thisSendTaskList.get(thisSendCursor++);
                sendQueue.put(scatterSend(thisMetaData, half0, container));
                if (thisSendCursor < thisSendTaskList.size()) {
                    List<Integer> half1 = thisSendTaskList.get(thisSendCursor);
                    // index 2 is the first rank of the task's target range (see scatterSend);
                    // NOTE(review): why only ranges starting at 0 or slaveNum / 2 are sent up front
                    // is defined by ScatterAllocate's plan, not visible here
                    int toRank = half1.get(2);
                    if (toRank == 0 || toRank == slaveNum / 2) {
                        sendQueue.put(scatterSend(thisMetaData, half1, container));
                        thisSendCursor ++;
                    }
                }
            } else {
                // root that never receives: enqueue every send task immediately
                for (; thisSendCursor < thisSendTaskList.size(); thisSendCursor++) {
                    sendQueue.put(scatterSend(thisMetaData, thisSendTaskList.get(thisSendCursor), container));
                }
            }
        }
        // recv at most once & send
        if (recvNum > 0) {
            // hand the local metadata to the receive side, then block until the received data is ready
            recvQueue.push(thisMetaData);
            MetaData<T> thatMetaData = recvResultQueue.take();
            thisMetaData = thatMetaData;
            thisMetaData.setCollective(Collective.SCATTER);
            // forward the remaining send tasks, built from the data just received
            for (int cursor = thisSendCursor; cursor < thisSendTaskList.size(); cursor++) {
                sendQueue.put(scatterSend(thisMetaData, thisSendTaskList.get(cursor), container));
            }
        }
        barrier();
    } catch (Exception e) {
        throw new Mp4jException(e);
    }
    return thisMetaData;
}

/**
 * ReduceScatterArray operation can be viewed as a combination of {@link #reduceArray(Object, Operand, IOperator, int, int, int)}
 * and {@link #scatterArray(Object, Operand, int[], int[], int)}.
 * @param arrData data array, each process have the same length.
 * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
 * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators})
 * @param from start position of arrData to be ReduceScattered
 * @param counts rank i receives counts[i] elements of the reduce-scatter result.
 * @return rank 0 receives the interval [from, from+counts[0]),
 * rank 1 receives the interval [from+counts[0], from+counts[0]+counts[1]), ...
* @throws Mp4jException */ public <T> T reduceScatterArray(T arrData, Operand operand, IOperator operator, int from, int[] counts) throws Mp4jException { if (counts.length != slaveNum) { throw new Mp4jException("counts array length must be equal to slaveNum"); } if (slaveNum == 1) { return arrData; } CommUtils.isFromCountsLegal(from, counts); try { int []froms = CommUtils.getFromsFromCount(from, counts, slaveNum); int []tos = CommUtils.getTosFromCount(from, counts, slaveNum); this.operand = operand; this.operand.setCollective(Collective.REDUCE_SCATTER); this.operand.setContainer(Container.ARRAY); this.operand.setOperator(operator); int sendRank = (rank - 1 + slaveNum) % slaveNum; ArrayMetaData<T> arrayMetaData = new ArrayMetaData<>(); arrayMetaData.setSrcRank(rank) .setDestRank((rank + 1) % slaveNum) .setStep(0) .setSum(1) .setCollective(Collective.REDUCE_SCATTER) .insert(rank, froms[sendRank], tos[sendRank]); arrayMetaData.setArrData(arrData); return (T)ringReduceScatter(arrayMetaData, Container.ARRAY).getArrData(); } catch (Exception e) { throw new Mp4jException(e); } } protected <T> List<Map<String, T>> reduceScatterMapSpecial(List<List<Map<String, T>>> mapDataListList, Operand operand, IOperator operator) throws Mp4jException { if (slaveNum == 1) { return mapDataListList.get(0); } if (mapDataListList.size() != slaveNum) { throw new Mp4jException("mapDataListList size=" + mapDataListList.size() + ", must be equal to slaveNum=" + slaveNum); } int blockNum = mapDataListList.get(0).size(); try { this.operand = operand; this.operand.setCollective(Collective.REDUCE_SCATTER); this.operand.setContainer(Container.MAP); this.operand.setOperator(operator); MapMetaData<T> mapMetaData = new MapMetaData<>(); mapMetaData.setSrcRank(rank) .setDestRank((rank + 1) % slaveNum) .setStep(0) .setSum(1) .setCollective(Collective.REDUCE_SCATTER); List<Map<String, T>> allmapDataList = new ArrayList<>(); for (int i = 0; i < mapDataListList.size(); i++) { List<Map<String, T>> maps = 
mapDataListList.get(i); for (int j = 0; j < maps.size(); j++) { allmapDataList.add(maps.get(j)); } } List<Map<String, T>> sendMapList = new ArrayList<>(); int sendIdx = (rank - 1 + slaveNum) % slaveNum; for (int t = 0; t < blockNum; t++) { int idx = sendIdx * blockNum + t; mapMetaData.insert(rank, allmapDataList.get(idx).size()); sendMapList.add(allmapDataList.get(idx)); } mapMetaData.setMapDataList(sendMapList); // send first block sendQueue.put(mapMetaData); // recv slaveNum - 1 times, send slaveNum - 2 for (int step = 1; step < slaveNum; step++) { MapMetaData recvMapMetaData = new MapMetaData(); List<Map<String, T>> recvMapList = new ArrayList<>(); int recvIdx = (rank - (step + 1) + slaveNum) % slaveNum; for (int t = 0; t < blockNum; t++) { int idx = recvIdx * blockNum + t; recvMapMetaData.insert(rank, allmapDataList.get(idx).size()); recvMapList.add(allmapDataList.get(idx)); } recvMapMetaData.setMapDataList(recvMapList); recvQueue.push(recvMapMetaData); MetaData<T> thatMetaData = recvResultQueue.take(); mapMetaData = thatMetaData.convertToMapMetaData(); mapMetaData.setSrcRank(rank) .setDestRank((rank + 1) % slaveNum) .setCollective(Collective.REDUCE_SCATTER); if (step < slaveNum - 1) { sendQueue.put(mapMetaData); } } barrier(); return mapMetaData.getMapDataList(); } catch (Exception e) { throw new Mp4jException(e); } } /** * reduceScatterMap operation can be viewed as a combination of {@link #reduceMap(Map, Operand, IOperator, int)} * and {@link #scatterMap(List, Operand, int)}. 
* @param mapDataList one map per rank; mapDataList.get(i) holds the entries to be reduced into rank i.
 *                    Its size must equal slaveNum.
 * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
 * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators})
 * @return rank i process will receive reduce of mapDataList.get(i) in all process
 *         (when slaveNum == 1, mapDataList.get(0) is returned unchanged)
 * @throws Mp4jException if mapDataList.size() != slaveNum, or wrapping any failure during the ring
 */
public <T> Map<String, T> reduceScatterMap(List<Map<String, T>> mapDataList, Operand operand, IOperator operator) throws Mp4jException {
    if (slaveNum == 1) {
        return mapDataList.get(0);
    }
    if (mapDataList.size() != slaveNum) {
        throw new Mp4jException("mapDataList size=" + mapDataList.size() + ", must be equal to slaveNum=" + slaveNum);
    }
    try {
        this.operand = operand;
        this.operand.setCollective(Collective.REDUCE_SCATTER);
        this.operand.setContainer(Container.MAP);
        this.operand.setOperator(operator);
        MapMetaData<T> mapMetaData = new MapMetaData<>();
        mapMetaData.setSrcRank(rank)
                .setDestRank((rank + 1) % slaveNum)
                .setStep(0)
                .setSum(1)
                .setCollective(Collective.REDUCE_SCATTER);
        // the full per-rank list is handed over; ringReduceScatter replaces it with the single map it sends
        mapMetaData.setMapDataList(mapDataList);
        MapMetaData<T> retMapMetaData = (MapMetaData<T>)ringReduceScatter(mapMetaData, Container.MAP).convertToMapMetaData();
        return retMapMetaData.getMapDataList().get(0);
    } catch (Exception e) {
        throw new Mp4jException(e);
    }
}

/**
 * Ring reduce-scatter shared by the array and map variants.
 *
 * Every rank sends to (rank + 1) % slaveNum. Over slaveNum - 1 steps it takes one result from
 * recvResultQueue, after pushing onto recvQueue the local data that result is presumably
 * combined with. Each received result except the last is forwarded to the next rank.
 *
 * MAP: thisMetaData.getMapDataList() must hold one map per rank. The first send is the map at
 * index (rank - 1 + slaveNum) % slaveNum. The local map pushed at step s is the one at index
 * (rank - (s + 1) + slaveNum) % slaveNum, so the last step uses index {@code rank}.
 * Note that thisMetaData's map list is overwritten with the single map being sent.
 * ARRAY: the same thisMetaData object is pushed at every step.
 *
 * @param thisMetaData metadata prepared by the caller (segments already inserted for ARRAY)
 * @param container    ARRAY or MAP
 * @return the metadata taken in the final step, or thisMetaData if the loop does not run
 * @throws Mp4jException wrapping any failure during send/receive or the final barrier
 */
private <T> MetaData<T> ringReduceScatter(MetaData thisMetaData, Container container) throws Mp4jException {
    try {
        List<Map<String, T>> storeMapList = null;
        if (container == Container.MAP){
            // keep the full per-rank list locally; only the predecessor's map goes on the wire first
            storeMapList = thisMetaData.getMapDataList();
            int sendIdx = (rank - 1 + slaveNum) % slaveNum;
            thisMetaData.insert(rank, storeMapList.get(sendIdx).size());
            thisMetaData.setMapDataList(Arrays.asList(storeMapList.get(sendIdx)));
        }
        // send first block
        sendQueue.put(thisMetaData);
        // recv slaveNum - 1 times, send slaveNum - 2
        for (int step = 1; step < slaveNum; step++) {
            if (container == Container.ARRAY) {
                recvQueue.push(thisMetaData);
            } else if (container == Container.MAP) {
                MapMetaData recvMapMetaData = new MapMetaData();
                int recvIdx = (rank - (step + 1) + slaveNum) % slaveNum;
                recvMapMetaData.insert(rank, storeMapList.get(recvIdx).size());
                recvMapMetaData.setMapDataList(Arrays.asList(storeMapList.get(recvIdx)));
                recvQueue.push(recvMapMetaData);
            }
            MetaData<T> thatMetaData = recvResultQueue.take();
            thisMetaData = thatMetaData;
            // re-address the result so it can be forwarded to the next rank on the ring
            thatMetaData.setSrcRank(rank)
                    .setDestRank((rank + 1) % slaveNum)
                    .setCollective(Collective.REDUCE_SCATTER);
            if (step < slaveNum - 1) {
                sendQueue.put(thatMetaData);
            }
        }
        barrier();
    } catch (Exception e) {
        throw new Mp4jException(e);
    }
    return thisMetaData;
}

/**
 * Reduces all array elements; the reduced result is located in the root process.
 * @param arrData data array, each process have the same length.
 * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
 * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators})
 * @param from start position to reduce
 * @param to end position to reduce
 * @param rootRank root rank
 * @return if this process is root, reduced array is returned, otherwise,
 * is invalid array, contains intermediate result.
* @throws Mp4jException wrapping any failure of the underlying reduce-scatter or gather
 */
public <T> T reduceArray(T arrData, Operand operand, IOperator operator, int from, int to, int rootRank) throws Mp4jException {
    if (slaveNum == 1) {
        return arrData;
    }
    CommUtils.isFromToLegal(from, to);
    try {
        int total = to - from;
        int chunk = total / slaveNum;
        int[] counts = new int[slaveNum];
        int[] froms = new int[slaveNum];
        int[] tos = new int[slaveNum];
        // split [from, to) into slaveNum equal chunks, one per rank
        for (int r = 0, start = from; r < slaveNum; r++, start += chunk) {
            counts[r] = chunk;
            froms[r] = start;
            tos[r] = start + chunk;
        }
        // the last rank absorbs the remainder of the integer division
        counts[slaveNum - 1] = total - (slaveNum - 1) * chunk;
        tos[slaveNum - 1] = to;

        T partial = reduceScatterArray(arrData, operand, operator, from, counts);
        return gatherArray(partial, operand, froms, tos, rootRank);
    } catch (Exception e) {
        throw new Mp4jException(e);
    }
}

/**
 * Reduces a single value across all processes.
 * @param value value to be reduced
 * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
 * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators})
 * @param rootRank root rank
 * @return if this process is root, the reduced value is returned; otherwise the value is not meaningful.
* @throws Mp4jException if the operand is not one of the eight supported operand types,
 *         or wrapping any failure of the underlying array reduce
 */
public <T> T reduce(T value, Operand operand, IOperator operator, int rootRank) throws Mp4jException {
    // every branch wraps the value in a one-element array and reuses reduceArray
    if (operand instanceof DoubleOperand) {
        double[] doubleArr = {(Double) value};
        doubleArr = reduceArray(doubleArr, operand, operator, 0, doubleArr.length, rootRank);
        return (T) Double.valueOf(doubleArr[0]);
    } else if (operand instanceof FloatOperand) {
        float[] floatArr = {(Float) value};
        floatArr = reduceArray(floatArr, operand, operator, 0, floatArr.length, rootRank);
        return (T) Float.valueOf(floatArr[0]);
    } else if (operand instanceof IntOperand) {
        int[] intArr = {(Integer) value};
        intArr = reduceArray(intArr, operand, operator, 0, intArr.length, rootRank);
        return (T) Integer.valueOf(intArr[0]);
    } else if (operand instanceof LongOperand) {
        long[] longArr = {(Long) value};
        longArr = reduceArray(longArr, operand, operator, 0, longArr.length, rootRank);
        return (T) Long.valueOf(longArr[0]);
    } else if (operand instanceof ObjectOperand) {
        T[] objectArr = (T[]) Array.newInstance(value.getClass(), 1);
        objectArr[0] = value;
        objectArr = reduceArray(objectArr, operand, operator, 0, objectArr.length, rootRank);
        return objectArr[0];
    } else if (operand instanceof StringOperand) {
        String[] stringArr = {(String) value};
        stringArr = reduceArray(stringArr, operand, operator, 0, stringArr.length, rootRank);
        return (T) stringArr[0];
    } else if (operand instanceof ShortOperand) {
        short[] shortArr = {(Short) value};
        shortArr = reduceArray(shortArr, operand, operator, 0, shortArr.length, rootRank);
        return (T) Short.valueOf(shortArr[0]);
    } else if (operand instanceof ByteOperand) {
        byte[] byteArr = {(Byte) value};
        byteArr = reduceArray(byteArr, operand, operator, 0, byteArr.length, rootRank);
        return (T) Byte.valueOf(byteArr[0]);
    } else {
        throw new Mp4jException("unknown operand:" + operand);
    }
}

/**
 * Similar with {@link #reduceArray(Object, Operand, IOperator, int, int, int)},
 * but the container is map. Keys are partitioned by {@code Math.floorMod(key.hashCode(), slaveNum)},
 * reduce-scattered so that each rank reduces its own partition, then gathered to the root.
 * @param mapData map data (keys must be non-null)
 * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands})
 * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators})
 * @param rootRank root rank
 * @return if this process is root, reduced map is returned, otherwise, invalid map or null is returned.
 * @throws Mp4jException wrapping any failure of the partitioning, reduce-scatter or gather
 */
public <T> Map<String, T> reduceMap(Map<String, T> mapData, Operand operand, IOperator operator, int rootRank) throws Mp4jException {
    if (slaveNum == 1) {
        return mapData;
    }
    try {
        int bucketCapacity = Math.max((int)((mapData.size() / slaveNum) * 1.2), 1);
        List<Map<String, T>> listMapData = new ArrayList<>(slaveNum);
        for (int i = 0; i < slaveNum; i++) {
            listMapData.add(new HashMap<>(bucketCapacity));
        }
        for (Map.Entry<String, T> entry : mapData.entrySet()) {
            String key = entry.getKey();
            // floorMod keeps the owner index in [0, slaveNum) even for negative hash codes,
            // without the per-key logging the old "idx < 0" fix-up emitted
            int idx = Math.floorMod(key.hashCode(), slaveNum);
            listMapData.get(idx).put(key, entry.getValue());
        }
        Map<String, T> reduceScatteredMap = reduceScatterMap(listMapData, operand, operator);
        return gatherMap(reduceScatteredMap, operand, rootRank);
    } catch (Exception e) {
        throw new Mp4jException(e);
    }
}

/**
 * Kryo serializer for a {@code Set} of elements: writes the size followed by every element,
 * each written with {@code valSerializer}; reads back into a {@link HashSet}.
 */
protected static class Mp4jSetSerializer<T> extends Serializer<Set<T>> {
    Serializer<T> valSerializer;
    Class<T> valType;

    public Mp4jSetSerializer(Serializer<T> valSerializer, Class<T> valType) {
        this.valSerializer = valSerializer;
        this.valType = valType;
    }

    @Override
    public void write(Kryo kryo, Output output, Set<T> object) {
        output.writeInt(object.size());
        for (T val : object) {
            valSerializer.write(kryo, output, val);
        }
    }

    @Override
    public Set<T> read(Kryo kryo, Input input, Class<Set<T>> type) {
        int size = input.readInt();
        Set<T> set = new HashSet<>(size);
        for (int i = 0; i < size; i++) {
            set.add(valSerializer.read(kryo, input, this.valType));
        }
        return set;
    }
}
/**
 * Kryo serializer for a {@code List} of elements: writes the size, then each element in
 * iteration order with {@code valSerializer}; reads back into an {@link ArrayList}.
 */
protected static class Mp4jListSerializer<T> extends Serializer<List<T>> {
    Serializer<T> valSerializer;
    Class<T> valType;

    public Mp4jListSerializer(Serializer<T> valSerializer, Class<T> valType) {
        this.valSerializer = valSerializer;
        this.valType = valType;
    }

    @Override
    public void write(Kryo kryo, Output output, List<T> object) {
        output.writeInt(object.size());
        Iterator<T> elements = object.iterator();
        while (elements.hasNext()) {
            valSerializer.write(kryo, output, elements.next());
        }
    }

    @Override
    public List<T> read(Kryo kryo, Input input, Class<List<T>> type) {
        int count = input.readInt();
        List<T> result = new ArrayList<T>(count);
        for (int remaining = count; remaining > 0; remaining--) {
            result.add(valSerializer.read(kryo, input, valType));
        }
        return result;
    }
}

/**
 * Set union: sets stored under the same key are merged (union) in the root process.
 * @param mapData map of key to set
 * @param rootRank root rank
 * @param elementSerializer Kryo serializer for a single element
 * @param elementType element object class
 * @return if this process is root, the sets sharing a key are unioned together,
 * otherwise, invalid map or null is returned.
 * @throws Mp4jException if the underlying map reduce fails
 */
public <T> Map<String, Set<T>> reduceMapSetUnion(Map<String, Set<T>> mapData, int rootRank, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException {
    IOperator union = new IObjectOperator<Set<T>>() {
        @Override
        public Set<T> apply(Set<T> o1, Set<T> o2) {
            // merge into the left operand in place and return it
            o1.addAll(o2);
            return o1;
        }
    };
    Operand operand = Operands.OBJECT_OPERAND(new Mp4jSetSerializer<>(elementSerializer, elementType), elementType);
    return reduceMap(mapData, operand, union, rootRank);
}

/**
 * Set union.
 * @param setData set data
 * @param rootRank root rank
 * @param elementSerializer Kryo serializer for a single element
 * @param elementType element object class
 * @return if this process is root, the set union is returned,
 * otherwise invalid set or null is returned.
* @throws Mp4jException */ public <T> Set<T> reduceSetUnion(Set<T> setData, int rootRank, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Map<String, Set<T>> mapTemp = new HashMap<>(1); mapTemp.put("key", setData); Map<String, Set<T>> mapReturn = reduceMapSetUnion(mapTemp, rootRank, elementSerializer, elementType); if (mapReturn != null) { return mapReturn.get("key"); } else { return null; } } /** * Set intersection, the set with the same key will be reduced(intersect) together. * @param mapData map set data * @param rootRank root rank * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return if this process is root, the set with the same key will be reduced(intersect) together. * otherwise, invalid map is returned. * @throws Mp4jException */ public <T> Map<String, Set<T>> reduceMapSetIntersection(Map<String, Set<T>> mapData, int rootRank, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Operand operand = Operands.OBJECT_OPERAND(new Mp4jSetSerializer<>(elementSerializer, elementType), elementType); IOperator operator = new IObjectOperator<Set<T>>() { @Override public Set<T> apply(Set<T> o1, Set<T> o2) { o1.retainAll(o2); return o1; } }; return reduceMap(mapData, operand, operator, rootRank); } /** * Set intersection. * @param setData set data * @param rootRank root rank * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return if this process is root, intersection is returned, * otherwise, invalid set or null is returned. 
* @throws Mp4jException */ public <T> Set<T> reduceSetIntersection(Set<T> setData, int rootRank, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Map<String, Set<T>> mapTemp = new HashMap<>(1); mapTemp.put("key", setData); Map<String, Set<T>> mapReturn = reduceMapSetIntersection(mapTemp, rootRank, elementSerializer, elementType); if (mapReturn != null) { return mapReturn.get("key"); } else { return null; } } /** * List concat, the lists with the same key will be reduced(concat) together. * @param mapData map list data * @param rootRank root rank * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return if this process is root, the lists with the same key will be reduced(concat) together, * otherwise, invalid map or null is returned. * @throws Mp4jException */ public <T> Map<String, List<T>> reduceMapListConcat(Map<String, List<T>> mapData, int rootRank, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Operand operand = Operands.OBJECT_OPERAND(new Mp4jListSerializer<>(elementSerializer, elementType), elementType); IOperator operator = new IObjectOperator<List<T>>() { @Override public List<T> apply(List<T> o1, List<T> o2) { for (T val : o2) { o1.add(val); } return o1; } }; return reduceMap(mapData, operand, operator, rootRank); } /** * List concat. * @param listData list data * @param rootRank root rank * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return if this process is root, the concated list is returned, * otherwise, invalid list or null is returned. 
* @throws Mp4jException */ public <T> List<T> reduceListConcat(List<T> listData, int rootRank, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Map<String, List<T>> mapTemp = new HashMap<>(1); mapTemp.put("key", listData); Map<String, List<T>> mapReturn = reduceMapListConcat(mapTemp, rootRank, elementSerializer, elementType); if (mapReturn != null) { return mapReturn.get("key"); } else { return null; } } /** * Different with reduce operation which only root process contains final reduced result, * while all process receive the same reduced result in allreduce operation. * * @param arrData data array, each process have the same length. * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands}) * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators}) * @param from start position to reduce * @param to end position to reduce * @return arrData[i](rank) = reduce(arrData[i](rank_0), arrData[i](rank_1), ..., arrData[i](rank_slavenumber-1)), * @throws Mp4jException */ public <T> T allreduceArray(T arrData, Operand operand, IOperator operator, int from, int to) throws Mp4jException { if (slaveNum == 1) { return arrData; } CommUtils.isFromToLegal(from, to); try { int[] counts = new int[slaveNum]; int avg = (to - from) / slaveNum; for (int r = 0; r < slaveNum; r++) { counts[r] = avg; } counts[slaveNum - 1] = (to - from) - ((slaveNum - 1) * avg); int[] froms = new int[slaveNum]; int[] tos = new int[slaveNum]; int fromidx = from; for (int r = 0; r < slaveNum; r++) { froms[r] = fromidx; tos[r] = fromidx + avg; fromidx += avg; } tos[slaveNum - 1] = to; T reduceScatteredArr = reduceScatterArray(arrData, operand, operator, from, counts); return allgatherArray(reduceScatteredArr, operand, froms, tos); } catch (Exception e) { throw new Mp4jException(e); } } /** * Similar with {@link #allreduceArray(Object, Operand, IOperator, int, int)}, * but it's realized by rpc 
communication. It is suited to small data. * * @param arrData data array, each process have the same length. * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands}) * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators}) * @param <T> * @return arrData[i](rank) = reduce(arrData[i](rank_0), arrData[i](rank_1), ..., arrData[i](rank_slavenumber-1)), * @throws Mp4jException */ public <T> T allreduceArrayRpc(T arrData, Operand operand, IOperator operator) throws Mp4jException { if (slaveNum == 1) { return arrData; } if (operand instanceof DoubleOperand) { ArrayPrimitiveWritable ret = server.primitiveArrayAllReduce(new ArrayPrimitiveWritable(arrData), rank); double []retArray = (double[])ret.get(); double []thisArray = (double[])arrData; IDoubleOperator doubleOperator = (IDoubleOperator) operator; int idx = 0; for (int i = 0; i < slaveNum; i++) { if (i == 0) { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = retArray[idx + j]; } } else { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = doubleOperator.apply(thisArray[j], retArray[idx + j]); } } idx += thisArray.length; } } else if (operand instanceof FloatOperand) { ArrayPrimitiveWritable ret = server.primitiveArrayAllReduce(new ArrayPrimitiveWritable(arrData), rank); float []retArray = (float[])ret.get(); float []thisArray = (float[])arrData; IFloatOperator floatOperator = (IFloatOperator) operator; int idx = 0; for (int i = 0; i < slaveNum; i++) { if (i == 0) { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = retArray[idx + j]; } } else { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = floatOperator.apply(thisArray[j], retArray[idx + j]); } } idx += thisArray.length; } } else if (operand instanceof IntOperand) { ArrayPrimitiveWritable ret = server.primitiveArrayAllReduce(new ArrayPrimitiveWritable(arrData), rank); int []retArray = (int[])ret.get(); int []thisArray = (int[])arrData; 
IIntOperator intOperator = (IIntOperator) operator; int idx = 0; for (int i = 0; i < slaveNum; i++) { if (i == 0) { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = retArray[idx + j]; } } else { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = intOperator.apply(thisArray[j], retArray[idx + j]); } } idx += thisArray.length; } } else if (operand instanceof LongOperand) { ArrayPrimitiveWritable ret = server.primitiveArrayAllReduce(new ArrayPrimitiveWritable(arrData), rank); long []retArray = (long[])ret.get(); long []thisArray = (long[])arrData; ILongOperator longOperator = (ILongOperator) operator; int idx = 0; for (int i = 0; i < slaveNum; i++) { if (i == 0) { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = retArray[idx + j]; } } else { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = longOperator.apply(thisArray[j], retArray[idx + j]); } } idx += thisArray.length; } } else if (operand instanceof ObjectOperand) { ObjectOperand objectOperand = (ObjectOperand)operand; objectOperand.setOperator(operator); ArrayPrimitiveWritable ret = server.arrayAllReduce(new ArrayPrimitiveWritable(objectOperand.convertToBytes(arrData)), rank); Input input = new Input(new ByteArrayInputStream((byte[])ret.get())); T retx = (T)objectOperand.readFromBytes(arrData, input, slaveNum); input.close(); return retx; } else if (operand instanceof StringOperand) { String[] thisArray = (String []) arrData; ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(1024); Output output = new Output(byteArrayOutputStream); KryoUtils.getKryo().writeClassAndObject(output, thisArray); output.close(); ArrayPrimitiveWritable ret = server.arrayAllReduce(new ArrayPrimitiveWritable(byteArrayOutputStream.toByteArray()), rank); Input input = new Input(new ByteArrayInputStream((byte[])ret.get())); String []retArray; IStringOperator stringOperator = (IStringOperator) operator; for (int i = 0; i < slaveNum; i++) { retArray = 
(String[])KryoUtils.getKryo().readClassAndObject(input); if (i == 0) { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = retArray[j]; } } else { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = stringOperator.apply(thisArray[j], retArray[j]); } } } input.close(); } else if (operand instanceof ShortOperand) { ArrayPrimitiveWritable ret = server.primitiveArrayAllReduce(new ArrayPrimitiveWritable(arrData), rank); short []retArray = (short[])ret.get(); short []thisArray = (short[])arrData; IShortOperator shortOperator = (IShortOperator) operator; int idx = 0; for (int i = 0; i < slaveNum; i++) { if (i == 0) { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = retArray[idx + j]; } } else { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = shortOperator.apply(thisArray[j], retArray[idx + j]); } } idx += thisArray.length; } } else if (operand instanceof ByteOperand) { ArrayPrimitiveWritable ret = server.primitiveArrayAllReduce(new ArrayPrimitiveWritable(arrData), rank); byte []retArray = (byte[])ret.get(); byte []thisArray = (byte[])arrData; IByteOperator byteOperator = (IByteOperator) operator; int idx = 0; for (int i = 0; i < slaveNum; i++) { if (i == 0) { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = retArray[idx + j]; } } else { for (int j = 0; j < thisArray.length; j++) { thisArray[j] = byteOperator.apply(thisArray[j], retArray[idx + j]); } } idx += thisArray.length; } } else { throw new Mp4jException("unknown operand:" + operand); } return arrData; } /** * Similar with {@link #allreduceArray(Object, Operand, IOperator, int, int)}, * this function only allreduce just only one elements. * @param value the value for reduce * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands}) * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators}) * @return all processes receive the same reduced value. 
* @throws Mp4jException */ public <T> T allreduce(T value, Operand operand, IOperator operator) throws Mp4jException { if (operand instanceof DoubleOperand) { double []doubleArr = new double[1]; doubleArr[0] = (Double) value; doubleArr = allreduceArray(doubleArr, operand, operator, 0, doubleArr.length); return (T)Double.valueOf(doubleArr[0]); } else if (operand instanceof FloatOperand) { float []floatArr = new float[1]; floatArr[0] = (Float) value; floatArr = allreduceArray(floatArr, operand, operator, 0, floatArr.length); return (T)Float.valueOf(floatArr[0]); } else if (operand instanceof IntOperand) { int []intArr = new int[1]; intArr[0] = (Integer) value; intArr = allreduceArray(intArr, operand, operator, 0, intArr.length); return (T)Integer.valueOf(intArr[0]); } else if (operand instanceof LongOperand) { long []longArr = new long[1]; longArr[0] = (Long) value; longArr = allreduceArray(longArr, operand, operator, 0, longArr.length); return (T)Long.valueOf(longArr[0]); } else if (operand instanceof ObjectOperand) { T []objectArr = (T[]) Array.newInstance(value.getClass(), 1); objectArr[0] = value; objectArr = allreduceArray(objectArr, operand, operator, 0, objectArr.length); return objectArr[0]; } else if (operand instanceof StringOperand) { String []stringArr = new String[1]; stringArr[0] = (String) value; stringArr = allreduceArray(stringArr, operand, operator, 0, stringArr.length); return (T)stringArr[0]; } else if (operand instanceof ShortOperand) { short []shortArr = new short[1]; shortArr[0] = (Short) value; shortArr = allreduceArray(shortArr, operand, operator, 0, shortArr.length); return (T)Short.valueOf(shortArr[0]); } else if (operand instanceof ByteOperand) { byte []byteArr = new byte[1]; byteArr[0] = (Byte) value; byteArr = allreduceArray(byteArr, operand, operator, 0, byteArr.length); return (T)Byte.valueOf(byteArr[0]); } else { throw new Mp4jException("unknown operand:" + operand); } } /** * Similar with {@link #allreduce(Object, Operand, 
IOperator)}, * but it's realized by rpc communication. It is suited to small data. * * @param value the value for reduce * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands}) * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators}) * @param <T> * @return * @throws Mp4jException */ public <T> T allreduceRpc(T value, Operand operand, IOperator operator) throws Mp4jException { if (operand instanceof DoubleOperand) { double []doubleArr = new double[1]; doubleArr[0] = (Double) value; doubleArr = allreduceArrayRpc(doubleArr, operand, operator); return (T)Double.valueOf(doubleArr[0]); } else if (operand instanceof FloatOperand) { float []floatArr = new float[1]; floatArr[0] = (Float) value; floatArr = allreduceArrayRpc(floatArr, operand, operator); return (T)Float.valueOf(floatArr[0]); } else if (operand instanceof IntOperand) { int []intArr = new int[1]; intArr[0] = (Integer) value; intArr = allreduceArrayRpc(intArr, operand, operator); return (T)Integer.valueOf(intArr[0]); } else if (operand instanceof LongOperand) { long []longArr = new long[1]; longArr[0] = (Long) value; longArr = allreduceArrayRpc(longArr, operand, operator); return (T)Long.valueOf(longArr[0]); } else if (operand instanceof ObjectOperand) { T []objectArr = (T[]) Array.newInstance(value.getClass(), 1); objectArr[0] = value; objectArr = allreduceArrayRpc(objectArr, operand, operator); return objectArr[0]; } else if (operand instanceof StringOperand) { String []stringArr = new String[1]; stringArr[0] = (String) value; stringArr = allreduceArrayRpc(stringArr, operand, operator); return (T)stringArr[0]; } else if (operand instanceof ShortOperand) { short []shortArr = new short[1]; shortArr[0] = (Short) value; shortArr = allreduceArrayRpc(shortArr, operand, operator); return (T)Short.valueOf(shortArr[0]); } else if (operand instanceof ByteOperand) { byte []byteArr = new byte[1]; byteArr[0] = (Byte) value; 
byteArr = allreduceArrayRpc(byteArr, operand, operator); return (T)Byte.valueOf(byteArr[0]); } else { throw new Mp4jException("unknown operand:" + operand); } } /** * Similar with {@link #allreduceArray(Object, Operand, IOperator, int, int)}, * the container of this function is map, the values which have same key will be reduced, * different processes can contain the same and different keys. * @param mapData map to be reduced * @param operand operand(there are 8 operands in {@link com.fenbi.mp4j.operand.Operands}) * @param operator operator(different operands provide different operator in {@link com.fenbi.mp4j.operator.Operators}) * @return key value map, the values which have same key will be reduced together. * @throws Mp4jException */ public <T> Map<String, T> allreduceMap(Map<String, T> mapData, Operand operand, IOperator operator) throws Mp4jException { if (slaveNum == 1) { return mapData; } try { List<Map<String, T>> listMapData = new ArrayList<>(slaveNum); for (int i = 0; i < slaveNum; i++) { listMapData.add(new HashMap<>(Math.max((int)((mapData.size() / slaveNum) * 1.2), 1))); } for (Map.Entry<String, T> entry : mapData.entrySet()) { String key = entry.getKey(); T val = entry.getValue(); int idx = key.hashCode() % slaveNum; if (idx < 0) { LOG.info(key + ", code:" + idx); idx += slaveNum; } listMapData.get(idx).put(key, val); } Map<String, T> reduceScatteredMap = reduceScatterMap(listMapData, operand, operator); List<Map<String, T>> allreducedMapList = allgatherMap(reduceScatteredMap, operand); Map<String, T> retMap = allreducedMapList.get(0); for (int i = 1; i < allreducedMapList.size(); i++) { Map<String, T> retMapTemp = allreducedMapList.get(i); for (Map.Entry<String, T> entry : retMapTemp.entrySet()) { retMap.put(entry.getKey(), entry.getValue()); } } return retMap; } catch (Exception e) { throw new Mp4jException(e); } } /** * Set union, the set with the same key will be reduced(union) together. 
* @param mapData map set data * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return the set with the same key will be reduced together. * @throws Mp4jException */ public <T> Map<String, Set<T>> allreduceMapSetUnion(Map<String, Set<T>> mapData, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Operand operand = Operands.OBJECT_OPERAND(new Mp4jSetSerializer<>(elementSerializer, elementType), elementType); IOperator operator = new IObjectOperator<Set<T>>() { @Override public Set<T> apply(Set<T> o1, Set<T> o2) { for (T val : o2) { o1.add(val); } return o1; } }; return allreduceMap(mapData, operand, operator); } /** * Set union * @param setData set data * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return set union result * @throws Mp4jException */ public <T> Set<T> allreduceSetUnion(Set<T> setData, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Map<String, Set<T>> mapTemp = new HashMap<>(1); mapTemp.put("key", setData); return allreduceMapSetUnion(mapTemp, elementSerializer, elementType).get("key"); } /** * Set intersection, the set with the same key will be reduced(intersect) together. * @param mapData map set data * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return the set with the same key will be reduced(intersect) together. 
* @throws Mp4jException */ public <T> Map<String, Set<T>> allreduceMapSetIntersection(Map<String, Set<T>> mapData, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Operand operand = Operands.OBJECT_OPERAND(new Mp4jSetSerializer<>(elementSerializer, elementType), elementType); IOperator operator = new IObjectOperator<Set<T>>() { @Override public Set<T> apply(Set<T> o1, Set<T> o2) { o1.retainAll(o2); return o1; } }; return allreduceMap(mapData, operand, operator); } /** * Set intersection * @param setData set data * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return set intersected result * @throws Mp4jException */ public <T> Set<T> allreduceSetIntersection(Set<T> setData, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Map<String, Set<T>> mapTemp = new HashMap<>(1); mapTemp.put("key", setData); return allreduceMapSetIntersection(mapTemp, elementSerializer, elementType).get("key"); } /** * List concat, the lists with the same key will be reduced(concat) together. * @param mapData map list data * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return the lists with the same key will be reduced(concat) together. 
* @throws Mp4jException */ public <T> Map<String, List<T>> allreduceMapListConcat(Map<String, List<T>> mapData, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Operand operand = Operands.OBJECT_OPERAND(new Mp4jListSerializer<>(elementSerializer, elementType), elementType); IOperator operator = new IObjectOperator<List<T>>() { @Override public List<T> apply(List<T> o1, List<T> o2) { for (T val : o2) { o1.add(val); } return o1; } }; return allreduceMap(mapData, operand, operator); } /** * list concat * @param listData list data * @param elementSerializer element object Kryo serializer * @param elementType element object class * @return list concated result * @throws Mp4jException */ public <T> List<T> allreduceListConcat(List<T> listData, Serializer<T> elementSerializer, Class<T> elementType) throws Mp4jException { Map<String, List<T>> mapTemp = new HashMap<>(1); mapTemp.put("key", listData); return allreduceMapListConcat(mapTemp, elementSerializer, elementType).get("key"); } private Input getInput() throws Mp4jException { Socket recvDataSock = null; Input input = null; try { recvDataSock = getRecvDataSocket(); input = new Input(recvDataSock.getInputStream()); LOG.debug("getInput:" + recvDataSock.isClosed()); } catch (Exception e) { throw new Mp4jException(e); } return input; } private Output getOutput(int targetRank) throws Mp4jException { Socket sendDataSock = null; Output output = null; try { sendDataSock = getSendDataSocket(targetRank); output = new Output(sendDataSock.getOutputStream()); LOG.debug("getOutput:" + sendDataSock.isClosed()); } catch (Exception e) { throw new Mp4jException(e); } return output; } }