package io.arabesque.computation.comm; import io.arabesque.cache.LZ4ObjectCache; import io.arabesque.computation.BasicComputation; import io.arabesque.computation.Computation; import io.arabesque.computation.MasterExecutionEngine; import io.arabesque.computation.WorkerContext; import io.arabesque.embedding.Embedding; import io.arabesque.pattern.Pattern; import org.apache.giraph.graph.Vertex; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.NullWritable; import java.io.IOException; import java.util.Iterator; public class CacheCommunicationStrategy<O extends Embedding> extends CommunicationStrategy<O> { private String partitionCounterKey; private String groupCounterKey; private long newIdCounter; private long newGroupCounter; private LZ4ObjectCache[] outputCaches; private IntWritable reusableDestinationId; private MessageWrapper reusableMessageWrapper; private Iterator<MessageWrapper> messageIterator; LZ4ObjectCache currentObjectCache; private long totalSizeEmbeddingsProcessed; private int numberOfPartitions; private boolean patternAggFilterDefined; private Computation computation; @Override public void initialize(int phase) { super.initialize(phase); WorkerContext workerContext = getWorkerContext(); numberOfPartitions = getWorkerContext().getNumberPartitions(); reusableDestinationId = new IntWritable(); reusableMessageWrapper = new MessageWrapper(); int partitionId = getExecutionEngine().getPartitionId(); partitionCounterKey = "newIdCounter" + partitionId; groupCounterKey = "groupCounter" + partitionId; newIdCounter = workerContext.getLongData(partitionCounterKey); newGroupCounter = workerContext.getLongData(groupCounterKey); if (newIdCounter == -1) { long countersPerPartition = Long.MAX_VALUE / numberOfPartitions; newIdCounter = countersPerPartition * partitionId; newGroupCounter = countersPerPartition * partitionId; String partitionCounterMaxKey = "newIdCounterMax" + partitionId; long maxCounterValueForPartition = newIdCounter + countersPerPartition - 1; workerContext.setData(partitionCounterMaxKey, maxCounterValueForPartition); } outputCaches = new LZ4ObjectCache[numberOfPartitions]; for (int i = 0; i < outputCaches.length; ++i) { outputCaches[i] = new LZ4ObjectCache(); } totalSizeEmbeddingsProcessed = 0; } @Override public int getNumPhases() { return 1; } @Override public void flush() { for (int i = 0; i < outputCaches.length; ++i) { LZ4ObjectCache cache = outputCaches[i]; reusableDestinationId.set(i); flushCache(i, cache); } } private void flushCache(int partitionId, LZ4ObjectCache outputCache) { if (!outputCache.isEmpty()) { reusableDestinationId.set(partitionId); reusableMessageWrapper.setMessage(outputCache); sendMessage(reusableDestinationId, reusableMessageWrapper); outputCache.reset(); } } @Override public void finish() { flush(); WorkerContext workerContext = getWorkerContext(); workerContext.setData(partitionCounterKey, newIdCounter); workerContext.setData(groupCounterKey, newGroupCounter); LongWritable longWritable = new LongWritable(); longWritable.set(totalSizeEmbeddingsProcessed); getExecutionEngine().aggregate(MasterExecutionEngine.AGG_PROCESSED_SIZE_CACHE, longWritable); } @Override public void startComputation(Vertex<IntWritable, NullWritable, NullWritable> vertex, Iterable<MessageWrapper> messages) { super.startComputation(vertex, messages); messageIterator = messages.iterator(); computation = getExecutionEngine().getComputation(); try { patternAggFilterDefined = computation.getClass().getMethod("aggregationFilter", Pattern.class).getDeclaringClass() != BasicComputation.class; } catch (NoSuchMethodException e) { throw new RuntimeException(e); } } @Override public O getNextInboundEmbedding() { if (currentObjectCache == null) { if (messageIterator.hasNext()) { currentObjectCache = messageIterator.next().getMessage(); } if (currentObjectCache == null) { return null; } else { currentObjectCache.prepareForIteration(); } } while (currentObjectCache.hasNext()) { O embedding = (O) currentObjectCache.next(); if (!patternAggFilterDefined || computation.aggregationFilter(embedding.getPattern())) { return embedding; } } totalSizeEmbeddingsProcessed += currentObjectCache.getByteArrayOutputCache().getPos(); currentObjectCache = null; return getNextInboundEmbedding(); } @Override public void addOutboundEmbedding(O expansion) { int destinationPartition = (int) ((newIdCounter++) % numberOfPartitions); LZ4ObjectCache outputCache = outputCaches[destinationPartition]; try { outputCache.addObject(expansion); } catch (IOException e) { throw new RuntimeException("Unable to add outbound embedding", e); } if (outputCache.overThreshold()) { flushCache(destinationPartition, outputCache); } } }