/** * * Copyright (c) 2017 ytk-mp4j https://github.com/yuantiku * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ package com.fenbi.mp4j.operand; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.Serializer; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import com.esotericsoftware.kryo.serializers.DeflateSerializer; import com.fenbi.mp4j.exception.Mp4jException; import com.fenbi.mp4j.meta.ArrayMetaData; import com.fenbi.mp4j.meta.MapMetaData; import com.fenbi.mp4j.meta.MetaData; import com.fenbi.mp4j.operator.*; import com.fenbi.mp4j.utils.KryoUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.lang.reflect.Array; import java.util.*; /** * @author xialong */ public class ObjectOperand<T> extends Operand { public static final Logger LOG = LoggerFactory.getLogger(ObjectOperand.class); public IObjectOperator<T> operator; Serializer<T> serializer; Class type; public ObjectOperand(Serializer<T> serializer, Class type) { this.serializer = serializer; this.type = type; } public Serializer<T> getSerializer() { return serializer; } public Class getType() { return type; } public byte[] convertToBytes(Object dataArray) { ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(1024); Output output = new Output(byteArrayOutputStream); T[] array = (T[]) dataArray; Kryo kryo = KryoUtils.getKryo(); for (int i = 0; i < array.length; i++) { serializer.write(kryo, output, array[i]); } output.close(); return byteArrayOutputStream.toByteArray(); } public T[] readFromBytes(Object origin, Input input, int slaveNum) { T[] originArr = (T[])origin; Kryo kryo = KryoUtils.getKryo(); for (int i = 0; i < slaveNum; i++) { if (i == 0) { for (int j = 0; j < originArr.length; j++) { originArr[j] = (T)serializer.read(kryo, input, type); } } else { for (int j = 0; j < originArr.length; j++) { T readed = (T)serializer.read(kryo, input, type); originArr[j] = operator.apply(originArr[j], readed); } } } return originArr; } public static class Mp4jObjectArraySerializer<T> extends Serializer<ArrayMetaData<T[]>> { ArrayMetaData<T[]> arrayMetaData; ArrayMetaData<T[]> thatArrMetaData; Serializer<T> serializer; Class<T> type; public Mp4jObjectArraySerializer(ArrayMetaData<T[]> arrayMetaData, Serializer<T> serializer, Class<T> type) { this.setAcceptsNull(true); this.arrayMetaData = arrayMetaData; this.serializer = serializer; this.type = type; } public ArrayMetaData<T[]> getThatArrMetaData() { return thatArrMetaData; } public void write(Kryo kryo, Output output, ArrayMetaData<T[]> object) { try { T []arrData = arrayMetaData.getArrData(); arrayMetaData.send(output); int arrSegNum = arrayMetaData.getSegNum(); for (int i = 0; i < arrSegNum; i++) { int from = arrayMetaData.getFrom(i); int to = arrayMetaData.getTo(i); for (int j = from; j < to; j++) { serializer.write(kryo, output, arrData[j]); } } } catch (IOException e) { LOG.error("double array write exception", e); System.exit(1); } } public ArrayMetaData<T[]> read(Kryo kryo, Input input, Class<ArrayMetaData<T[]>> type) { try { T []arrData = arrayMetaData.getArrData(); thatArrMetaData = arrayMetaData.recv(input); int arrSegNum = thatArrMetaData.getSegNum(); for (int i = 0; i < arrSegNum; i++) { int from = thatArrMetaData.getFrom(i); int to = thatArrMetaData.getTo(i); for (int j = from; j < to; j++) { arrData[j] = serializer.read(kryo, input, this.type); } } thatArrMetaData.setArrData(arrData); } catch (IOException e) { LOG.error("double array read exception", e); System.exit(1); } return thatArrMetaData; } } public static class Mp4jObjectMapSerializer<T> extends Serializer<MapMetaData<T>> { MapMetaData<T> mapMetaData; MapMetaData<T> thatMapMetaData; Serializer<T> serializer; Class<T> type; public Mp4jObjectMapSerializer(MapMetaData<T> mapMetaData, Serializer<T> serializer, Class<T> type) { this.setAcceptsNull(true); this.mapMetaData = mapMetaData; this.serializer = serializer; this.type = type; } public MapMetaData<T> getThatMapMetaData() { return thatMapMetaData; } public void write(Kryo kryo, Output output, MapMetaData<T> object) { try { List<Map<String, T>> mapDataList = mapMetaData.getMapDataList(); mapMetaData.send(output); int mapSegNum = mapMetaData.getSegNum(); for (int i = 0; i < mapSegNum; i++) { Map<String, T> mapData = mapDataList.get(i); for (Map.Entry<String, T> entry : mapData.entrySet()) { output.writeString(entry.getKey()); serializer.write(kryo, output, entry.getValue()); } if (mapMetaData.getCollective() == Collective.GATHER || mapMetaData.getCollective() == Collective.SCATTER || mapMetaData.getCollective() == Collective.REDUCE_SCATTER) { mapData.clear(); } } } catch (IOException e) { LOG.error("double array write exception", e); System.exit(1); } } @Override public MapMetaData<T> read(Kryo kryo, Input input, Class<MapMetaData<T>> type) { try { thatMapMetaData = mapMetaData.recv(input); int thatMapSegNum = thatMapMetaData.getSegNum(); List<Map<String, T>> mapDataList = new ArrayList<>(thatMapSegNum); thatMapMetaData.setMapDataList(mapDataList); for (int i = 0; i < thatMapSegNum; i++) { int dataNum = thatMapMetaData.getDataNum(i); Map<String, T> mapData = new HashMap<>(dataNum); mapDataList.add(mapData); for (int j = 0; j < dataNum; j++) { String key = input.readString(); T val = serializer.read(kryo, input, this.type); mapData.put(key, val); } } } catch (IOException e) { LOG.error("double array read exception", e); System.exit(1); } return thatMapMetaData; } } public static class Mp4jObjectArrayReduceSerializer<T> extends Serializer<ArrayMetaData<T[]>> { ArrayMetaData<T[]> arrayMetaData; ArrayMetaData<T[]> thatArrMetaData; IObjectOperator<T> operator; Serializer<T> serializer; Class<T> type; public Mp4jObjectArrayReduceSerializer(ArrayMetaData<T[]> arrayMetaData, IObjectOperator<T> operator, Serializer<T> serializer, Class<T> type ) { this.setAcceptsNull(true); this.arrayMetaData = arrayMetaData; this.operator = operator; this.serializer = serializer; this.type = type; } public ArrayMetaData<T[]> getThatArrMetaData() { return thatArrMetaData; } public void write(Kryo kryo, Output output, ArrayMetaData<T[]> object) { } public ArrayMetaData<T[]> read(Kryo kryo, Input input, Class<ArrayMetaData<T[]>> type) { try { T []arrData = arrayMetaData.getArrData(); thatArrMetaData = arrayMetaData.recv(input); int arrSegNum = thatArrMetaData.getSegNum(); for (int i = 0; i < arrSegNum; i++) { int from = thatArrMetaData.getFrom(i); int to = thatArrMetaData.getTo(i); for (int j = from; j < to; j++) { arrData[j] = operator.apply(arrData[j], serializer.read(kryo, input, this.type)); } } thatArrMetaData.setArrData(arrData); } catch (IOException e) { LOG.error("double array read exception", e); System.exit(1); } return thatArrMetaData; } } public static class Mp4jObjectMapReduceSerializer<T> extends Serializer<MapMetaData<T>> { MapMetaData<T> mapMetaData; MapMetaData<T> thatMapMetaData; IObjectOperator<T> operator; Serializer<T> serializer; Class<T> type; public Mp4jObjectMapReduceSerializer(MapMetaData<T> mapMetaData, IObjectOperator<T> operator, Serializer<T> serializer, Class<T> type ) { this.setAcceptsNull(true); this.mapMetaData = mapMetaData; this.operator = operator; this.serializer = serializer; this.type = type; } public MapMetaData<T> getThatMapMetaData() { return thatMapMetaData; } public void write(Kryo kryo, Output output, MapMetaData<T> object) { } public MapMetaData<T> read(Kryo kryo, Input input, Class<MapMetaData<T>> type) { try { thatMapMetaData = mapMetaData.recv(input); int thatMapSegNum = thatMapMetaData.getSegNum(); List<Map<String, T>> thatMapListData = new ArrayList<>(thatMapSegNum); List<Integer> thatDataNums = new ArrayList<>(thatMapSegNum); for (int i = 0; i < thatMapSegNum; i++) { Map<String, T> thisMapData = mapMetaData.getMapDataList().get(i); int dataNum = thatMapMetaData.getDataNum(i); for (int j = 0; j < dataNum; j++) { String key = input.readString(); T val = serializer.read(kryo, input, this.type); T thisVal = thisMapData.get(key); if (thisVal == null) { thisMapData.put(key, val); } else { thisMapData.put(key, operator.apply(thisVal, val)); } } thatMapListData.add(thisMapData); thatDataNums.add(thisMapData.size()); } thatMapMetaData.setMapDataList(thatMapListData); thatMapMetaData.setDataNums(thatDataNums); } catch (IOException e) { LOG.error("double array read exception", e); System.exit(1); } return thatMapMetaData; } } @Override public void send(Output output, MetaData metaData) throws IOException, Mp4jException { Kryo kryo = KryoUtils.getKryo(); try { switch (container) { case ARRAY: Serializer arrSerializer = new ObjectOperand.Mp4jObjectArraySerializer( metaData.convertToArrayMetaData(), serializer, type); if (compress) { arrSerializer = new DeflateSerializer(arrSerializer); } arrSerializer.write(kryo, output, null); break; case MAP: Serializer mapSerializer = new ObjectOperand.Mp4jObjectMapSerializer( metaData.convertToMapMetaData(), serializer, type); if (compress) { mapSerializer = new DeflateSerializer(mapSerializer); } mapSerializer.write(kryo, output, null); break; default: throw new Mp4jException("unsupported container:" + container); } } catch (Exception e) { LOG.error("send exception", e); throw new Mp4jException(e); } finally { output.close(); } } @Override public MetaData recv(Input input, MetaData metaData) throws IOException, Mp4jException { MetaData retMetaData = null; Kryo kryo = KryoUtils.getKryo(); try { switch (collective) { case GATHER: case SCATTER: case ALL_GATHER: switch (container) { case ARRAY: Serializer<ArrayMetaData<T[]>> arrSerializer = new ObjectOperand.Mp4jObjectArraySerializer( metaData.convertToArrayMetaData(), serializer, type); if (compress) { arrSerializer = new DeflateSerializer(arrSerializer); } retMetaData = arrSerializer.read(kryo, input, null); break; case MAP: Serializer<MapMetaData<T>> mapSerializer = new ObjectOperand.Mp4jObjectMapSerializer( metaData.convertToMapMetaData(), serializer, type); if (compress) { mapSerializer = new DeflateSerializer(mapSerializer); } retMetaData = mapSerializer.read(kryo, input, null); break; default: throw new Mp4jException("unsupported container:" + container); } break; case REDUCE_SCATTER: switch (container) { case ARRAY: Serializer<ArrayMetaData<T[]>> arrSerializer = new ObjectOperand.Mp4jObjectArrayReduceSerializer( metaData.convertToArrayMetaData(), operator, serializer, type); if (compress) { arrSerializer = new DeflateSerializer(arrSerializer); } retMetaData = arrSerializer.read(kryo, input, null); break; case MAP: Serializer<MapMetaData<T>> mapSerializer = new ObjectOperand.Mp4jObjectMapReduceSerializer( metaData.convertToMapMetaData(), operator, serializer, type); if (compress) { mapSerializer = new DeflateSerializer(mapSerializer); } retMetaData = mapSerializer.read(kryo, input, null); break; default: throw new Mp4jException("unsupported container:" + container); } break; default: throw new Mp4jException("unsupported basic collective:" + collective); } } catch (Exception e) { throw new Mp4jException(e); } finally { input.close(); } return retMetaData; } @Override public void setOperator(IOperator operator) { this.operator = (IObjectOperator)operator; } @Override public void threadCopy(MetaData fromMetaData, MetaData toMetaData) throws Mp4jException { if (fromMetaData instanceof ArrayMetaData) { ArrayMetaData<T[]> fromArrayMetaData = fromMetaData.convertToArrayMetaData(); ArrayMetaData<T[]> toArrayMetaData = toMetaData.convertToArrayMetaData(); T[] fromArrayData = fromArrayMetaData.getArrData(); T[] toArrayData = toArrayMetaData.getArrData(); int segNum = fromArrayMetaData.getSegNum(); for (int i = 0; i < segNum; i++) { int from = fromArrayMetaData.getFrom(i) >= 0 ? fromArrayMetaData.getFrom(i) : 0; int to = fromArrayMetaData.getTo(i) >= 0 ? fromArrayMetaData.getTo(i) : toArrayData.length; for (int j = from; j < to; j++) { toArrayData[j] = fromArrayData[j]; } } } else if (fromMetaData instanceof MapMetaData) { toMetaData.setMapDataList(fromMetaData.getMapDataList()); } else { throw new Mp4jException("unknown format metadata!" + fromMetaData); } } @Override public void threadArrayAllCopy(MetaData fromMetaData, MetaData toMetaData) throws Mp4jException { if (fromMetaData instanceof ArrayMetaData) { ArrayMetaData<T[]> fromArrayMetaData = fromMetaData.convertToArrayMetaData(); ArrayMetaData<T[]> toArrayMetaData = toMetaData.convertToArrayMetaData(); T[] fromArrayData = fromArrayMetaData.getArrData(); T[] toArrayData = toArrayMetaData.getArrData(); System.arraycopy(fromArrayData, 0, toArrayData, 0, fromArrayData.length); } else { throw new Mp4jException("threadArrayAllCopy unsupport format metadata!" + fromMetaData); } } @Override public void threadMerge(MetaData fromMetaData, MetaData toMetaData) throws Mp4jException { if (fromMetaData instanceof ArrayMetaData) { ArrayMetaData<T[]> fromArrayMetaData = fromMetaData.convertToArrayMetaData(); ArrayMetaData<T[]> toArrayMetaData = toMetaData.convertToArrayMetaData(); T[] fromArrayData = fromArrayMetaData.getArrData(); T[] toArrayData = toArrayMetaData.getArrData(); int segNum = fromArrayMetaData.getSegNum(); for (int i = 0; i < segNum; i++) { int from = fromArrayMetaData.getFrom(i) >= 0 ? fromArrayMetaData.getFrom(i) : 0; int to = fromArrayMetaData.getTo(i) >= 0 ? fromArrayMetaData.getTo(i) : toArrayData.length; for (int j = from; j < to; j++) { toArrayData[j] = fromArrayData[j]; } } toArrayMetaData.setSum(toArrayMetaData.getSum() + fromArrayMetaData.getSum()); toArrayMetaData.setSegNum(toArrayMetaData.getSegNum() + fromArrayMetaData.getSegNum()); toArrayMetaData.getRanks().addAll(fromArrayMetaData.getRanks()); toArrayMetaData.getSegFroms().addAll(fromArrayMetaData.getSegFroms()); toArrayMetaData.getSegTos().addAll(fromArrayMetaData.getSegTos()); } else if (fromMetaData instanceof MapMetaData) { MapMetaData fromMapMetaData = fromMetaData.convertToMapMetaData(); MapMetaData toMapMetaData = toMetaData.convertToMapMetaData(); toMetaData.setSum(fromMetaData.getSum() + toMetaData.getSum()); toMetaData.setSegNum(1); toMapMetaData.getRanks().addAll(fromMapMetaData.getRanks()); List<Map<String, T>> fromMapDataList = fromMetaData.getMapDataList(); Map<String, T> thisMap = (Map<String, T>)toMapMetaData.getMapDataList().get(0); for (Map<String, T> thatMap : fromMapDataList) { for (Map.Entry<String, T> entry : thatMap.entrySet()) { thisMap.put(entry.getKey(), entry.getValue()); } } toMapMetaData.setDataNums(Arrays.asList(thisMap.size())); } else { throw new Mp4jException("unknown format metadata!" + fromMetaData); } } @Override public void threadReduce(MetaData fromMetaData, MetaData toMetaData) throws Mp4jException { if (fromMetaData instanceof ArrayMetaData) { ArrayMetaData<T[]> fromArrayMetaData = fromMetaData.convertToArrayMetaData(); ArrayMetaData<T[]> toArrayMetaData = toMetaData.convertToArrayMetaData(); T []fromArrayData = fromArrayMetaData.getArrData(); T []toArrayData = toArrayMetaData.getArrData(); int segNum = fromArrayMetaData.getSegNum(); for (int i = 0; i < segNum; i++) { int from = fromArrayMetaData.getFrom(i) >= 0 ? fromArrayMetaData.getFrom(i) : 0; int to = fromArrayMetaData.getTo(i) >= 0 ? fromArrayMetaData.getTo(i) : toArrayData.length; for (int j = from; j < to; j++) { toArrayData[j] = operator.apply(toArrayData[j], fromArrayData[j]); } } toArrayMetaData.setSum(toArrayMetaData.getSum() + fromArrayMetaData.getSum()); toArrayMetaData.setSegNum(1); } else if (fromMetaData instanceof MapMetaData) { MapMetaData<T> fromMapMetaData = fromMetaData.convertToMapMetaData(); MapMetaData<T> toMapMetaData = toMetaData.convertToMapMetaData(); List<Map<String, T>> fromMapDataList = fromMapMetaData.getMapDataList(); List<Map<String, T>> toMapDataList = toMapMetaData.getMapDataList(); for (int i = 0; i < fromMapDataList.size(); i++) { Map<String, T> fromMapData = fromMapDataList.get(i); Map<String, T> toMapData = toMapDataList.get(i); for (Map.Entry<String, T> entry : fromMapData.entrySet()) { String key = entry.getKey(); T val = entry.getValue(); T toVal = toMapData.get(key); if (toVal == null) { toMapData.put(key, val); } else { toMapData.put(key, operator.apply(val, toVal)); } } } toMetaData.setSum(fromMetaData.getSum() + toMetaData.getSum()); } else { throw new Mp4jException("unknown format metadata!" + fromMetaData); } } }