package org.apache.tez.runtime.library.common.shuffle; import com.google.common.collect.Lists; import com.google.protobuf.ByteString; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.io.compress.CompressionInputStream; import org.apache.hadoop.io.compress.CompressionOutputStream; import org.apache.hadoop.io.compress.Compressor; import org.apache.hadoop.io.compress.Decompressor; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.tez.common.TezCommonUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.common.TezUtilsInternal; import org.apache.tez.common.counters.TezCounters; import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.runtime.api.Event; import org.apache.tez.runtime.api.InputContext; import org.apache.tez.runtime.api.OutputContext; import org.apache.tez.runtime.api.events.CompositeDataMovementEvent; import org.apache.tez.runtime.api.events.VertexManagerEvent; import org.apache.tez.runtime.api.impl.ExecutionContextImpl; import org.apache.tez.runtime.library.api.TezRuntimeConfiguration; import org.apache.tez.runtime.library.common.InputAttemptIdentifier; import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils.FetchStatsLogger; import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord; import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord; import org.apache.tez.runtime.library.partitioner.HashPartitioner; import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Matchers; import org.slf4j.Logger; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import 
java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.SocketTimeoutException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.BitSet; import java.util.concurrent.ThreadLocalRandom; import java.util.List; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * <p/> * http://www.apache.org/licenses/LICENSE-2.0 * <p/> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/
/**
 * Unit tests for {@link ShuffleUtils}: spill-event generation ({@code generateEventOnSpill}),
 * exception translation in {@code shuffleToMemory}, checksum handling in {@code shuffleToDisk},
 * and throttled logging via {@link FetchStatsLogger}.
 */
public class TestShuffleUtils {
  // Mocked output context shared by the generateEventOnSpill tests; built in setup().
  private OutputContext outputContext;
  private Configuration conf;
  private FileSystem localFs;
  // Local working directory for the generated index files.
  private Path workingDir;

  /**
   * Builds a mocked {@link InputContext} with an app id, source vertex name and fresh counters.
   * NOTE(review): not referenced by any test in this file — presumably kept for future tests.
   */
  private InputContext createTezInputContext() {
    ApplicationId applicationId = ApplicationId.newInstance(1, 1);
    InputContext inputContext = mock(InputContext.class);
    doReturn(applicationId).when(inputContext).getApplicationId();
    doReturn("sourceVertex").when(inputContext).getSourceVertexName();
    when(inputContext.getCounters()).thenReturn(new TezCounters());
    return inputContext;
  }

  /**
   * Builds a mocked {@link OutputContext} whose shuffle service metadata advertises port 80
   * and whose execution context reports "localhost".
   *
   * @return mocked output context with vertex/output indices, DAG attempt number,
   *         destination vertex name and fresh counters stubbed
   * @throws IOException if writing the service metadata buffer fails
   */
  private OutputContext createTezOutputContext() throws IOException {
    ApplicationId applicationId = ApplicationId.newInstance(1, 1);
    OutputContext outputContext = mock(OutputContext.class);
    ExecutionContextImpl executionContext = mock(ExecutionContextImpl.class);
    doReturn("localhost").when(executionContext).getHostName();
    doReturn(executionContext).when(outputContext).getExecutionContext();
    // Serialize the shuffle port (80) the way the shuffle handler would publish it.
    DataOutputBuffer serviceProviderMetaData = new DataOutputBuffer();
    serviceProviderMetaData.writeInt(80);
    doReturn(ByteBuffer.wrap(serviceProviderMetaData.getData())).when(outputContext)
        .getServiceProviderMetaData
            (conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
                TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT));
    doReturn(1).when(outputContext).getTaskVertexIndex();
    doReturn(1).when(outputContext).getOutputIndex();
    doReturn(0).when(outputContext).getDAGAttemptNumber();
    doReturn("destVertex").when(outputContext).getDestinationVertexName();
    when(outputContext.getCounters()).thenReturn(new TezCounters());
    return outputContext;
  }

  @Before
  public void setup() throws Exception {
    // conf must be assigned before createTezOutputContext(), which reads the aux-service key from it.
    conf = new Configuration();
    outputContext = createTezOutputContext();
    conf.set("fs.defaultFS", "file:///");
    localFs = FileSystem.getLocal(conf);
    workingDir = new Path(
        new Path(System.getProperty("test.build.data", "/tmp")),
        TestShuffleUtils.class.getName())
        .makeQualified(localFs.getUri(), localFs.getWorkingDirectory());
    String localDirs = workingDir.toString();
    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, Text.class.getName());
    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, Text.class.getName());
    conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS,
        HashPartitioner.class.getName());
    conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, localDirs);
  }

  /**
   * Writes a spill index file with {@code numPartitions} entries of 200 compressed bytes each.
   * Even-numbered partitions (or all of them when {@code allEmptyPartitions} is set) get a raw
   * length of 0, which marks them as empty — so with 10 partitions and
   * {@code allEmptyPartitions=false}, exactly 5 partitions (0,2,4,6,8) are empty.
   *
   * @param allEmptyPartitions when true, every partition is marked empty
   * @return path of the written index file
   */
  private Path createIndexFile(int numPartitions, boolean allEmptyPartitions) throws IOException {
    Path path = new Path(workingDir, "file.index.out");
    TezSpillRecord spillRecord = new TezSpillRecord(numPartitions);
    long startOffset = 0;
    long partLen = 200; //compressed
    for(int i=0;i<numPartitions;i++) {
      // Random non-zero raw length, overridden to 0 below for empty partitions.
      long rawLen = ThreadLocalRandom.current().nextLong(100, 200);
      if (i % 2 == 0 || allEmptyPartitions) {
        rawLen = 0; //indicates empty partition, see TEZ-3605
      }
      TezIndexRecord indexRecord = new TezIndexRecord(startOffset, rawLen, partLen);
      startOffset += partLen;
      spillRecord.putIndex(indexRecord, i);
    }
    spillRecord.writeToFile(path, conf, FileSystem.getLocal(conf).getRaw());
    return path;
  }

  /**
   * Without final merge and not the last spill: exactly one CompositeDataMovementEvent is
   * generated, carrying the spill id, lastEvent=false, and a bitset with the 5 empty partitions.
   */
  @Test
  public void testGenerateOnSpillEvent() throws Exception {
    List<Event> events = Lists.newLinkedList();
    Path indexFile = createIndexFile(10, false);

    boolean finalMergeEnabled = false;
    boolean isLastEvent = false;
    int spillId = 0;
    int physicalOutputs = 10;
    String pathComponent = "/attempt_x_y_0/file.out";
    String auxiliaryService = conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
        TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);
    ShuffleUtils.generateEventOnSpill(events, finalMergeEnabled, isLastEvent,
        outputContext, spillId, new TezSpillRecord(indexFile, conf),
        physicalOutputs, true, pathComponent, null, false, auxiliaryService,
        TezCommonUtils.newBestCompressionDeflater());

    Assert.assertTrue(events.size() == 1);
    Assert.assertTrue(events.get(0) instanceof CompositeDataMovementEvent);

    CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(0);
    Assert.assertTrue(cdme.getCount() == physicalOutputs);
    Assert.assertTrue(cdme.getSourceIndexStart() == 0);

    ByteBuffer payload = cdme.getUserPayload();
    ShuffleUserPayloads.DataMovementEventPayloadProto dmeProto =
        ShuffleUserPayloads.DataMovementEventPayloadProto.parseFrom(ByteString.copyFrom(payload));

    // Without final merge the per-spill details must be present in the payload.
    Assert.assertTrue(dmeProto.getSpillId() == 0);
    Assert.assertTrue(dmeProto.hasLastEvent() && !dmeProto.getLastEvent());

    byte[] emptyPartitions =
        TezCommonUtils.decompressByteStringToByteArray(dmeProto.getEmptyPartitions());
    BitSet emptyPartitionsBitSet = TezUtilsInternal.fromByteArray(emptyPartitions);
    // createIndexFile(10, false) marks partitions 0,2,4,6,8 empty.
    Assert.assertTrue("emptyPartitionBitSet cardinality (expecting 5) = " + emptyPartitionsBitSet
        .cardinality(), emptyPartitionsBitSet.cardinality() == 5);

    events.clear();
  }

  /**
   * With final merge and last event: a VertexManagerEvent is emitted ahead of the
   * CompositeDataMovementEvent, and the DME payload omits spill id / lastEvent details.
   */
  @Test
  public void testGenerateOnSpillEvent_With_FinalMerge() throws Exception {
    List<Event> events = Lists.newLinkedList();
    Path indexFile = createIndexFile(10, false);

    boolean finalMergeEnabled = true;
    boolean isLastEvent = true;
    int spillId = 0;
    int physicalOutputs = 10;
    String pathComponent = "/attempt_x_y_0/file.out";
    String auxiliaryService = conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
        TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);

    //normal code path where we do final merge all the time
    ShuffleUtils.generateEventOnSpill(events, finalMergeEnabled, isLastEvent,
        outputContext, spillId, new TezSpillRecord(indexFile, conf),
        physicalOutputs, true, pathComponent, null, false, auxiliaryService,
        TezCommonUtils.newBestCompressionDeflater());

    Assert.assertTrue(events.size() == 2); //one for VM
    Assert.assertTrue(events.get(0) instanceof VertexManagerEvent);
    Assert.assertTrue(events.get(1) instanceof CompositeDataMovementEvent);

    CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(1);
    Assert.assertTrue(cdme.getCount() == physicalOutputs);
    Assert.assertTrue(cdme.getSourceIndexStart() == 0);

    ShuffleUserPayloads.DataMovementEventPayloadProto dmeProto =
        ShuffleUserPayloads.DataMovementEventPayloadProto.parseFrom(ByteString.copyFrom(
            cdme.getUserPayload()));

    //With final merge, spill details should not be present
    Assert.assertFalse(dmeProto.hasSpillId());
    Assert.assertFalse(dmeProto.hasLastEvent() || dmeProto.getLastEvent());

    byte[] emptyPartitions = TezCommonUtils.decompressByteStringToByteArray(dmeProto
        .getEmptyPartitions());
    BitSet emptyPartitionsBitSet = TezUtilsInternal.fromByteArray(emptyPartitions);
    // createIndexFile(10, false) marks partitions 0,2,4,6,8 empty.
    Assert.assertTrue("emptyPartitionBitSet cardinality (expecting 5) = " + emptyPartitionsBitSet
        .cardinality(), emptyPartitionsBitSet.cardinality() == 5);
  }

  /**
   * All partitions empty, final merge disabled, last spill: spill details are present,
   * the path component is dropped (empty string — no data to fetch), and the empty-partition
   * bitset covers all 10 partitions.
   */
  @Test
  public void testGenerateOnSpillEvent_With_All_EmptyPartitions() throws Exception {
    List<Event> events = Lists.newLinkedList();

    //Create an index file with all empty partitions
    Path indexFile = createIndexFile(10, true);

    boolean finalMergeDisabled = false;
    boolean isLastEvent = true;
    int spillId = 0;
    int physicalOutputs = 10;
    String pathComponent = "/attempt_x_y_0/file.out";
    String auxiliaryService = conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
        TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);

    //normal code path where we do final merge all the time
    ShuffleUtils.generateEventOnSpill(events, finalMergeDisabled, isLastEvent,
        outputContext, spillId, new TezSpillRecord(indexFile, conf),
        physicalOutputs, true, pathComponent, null, false, auxiliaryService,
        TezCommonUtils.newBestCompressionDeflater());

    Assert.assertTrue(events.size() == 2); //one for VM
    Assert.assertTrue(events.get(0) instanceof VertexManagerEvent);
    Assert.assertTrue(events.get(1) instanceof CompositeDataMovementEvent);

    CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(1);
    Assert.assertTrue(cdme.getCount() == physicalOutputs);
    Assert.assertTrue(cdme.getSourceIndexStart() == 0);

    ShuffleUserPayloads.DataMovementEventPayloadProto dmeProto =
        ShuffleUserPayloads.DataMovementEventPayloadProto.parseFrom(ByteString.copyFrom(
            cdme.getUserPayload()));

    //spill details should be present
    Assert.assertTrue(dmeProto.getSpillId() == 0);
    Assert.assertTrue(dmeProto.hasLastEvent() && dmeProto.getLastEvent());

    Assert.assertTrue(dmeProto.getPathComponent().equals(""));

    byte[] emptyPartitions = TezCommonUtils.decompressByteStringToByteArray(dmeProto
        .getEmptyPartitions());
    BitSet emptyPartitionsBitSet = TezUtilsInternal.fromByteArray(emptyPartitions);
    Assert.assertTrue("emptyPartitionBitSet cardinality (expecting 10) = " + emptyPartitionsBitSet
        .cardinality(), emptyPartitionsBitSet.cardinality() == 10);
  }

  /**
   * An {@link InternalError} thrown by a Configurable codec's stream during shuffleToMemory
   * must be translated into an IOException whose cause is the InternalError and whose message
   * carries the codec's error text.
   */
  @Test
  public void testInternalErrorTranslation() throws Exception {
    String codecErrorMsg = "codec failure";
    CompressionInputStream mockCodecStream = mock(CompressionInputStream.class);
    when(mockCodecStream.read(any(byte[].class), anyInt(), anyInt()))
        .thenThrow(new InternalError(codecErrorMsg));
    Decompressor mockDecoder = mock(Decompressor.class);
    CompressionCodec mockCodec = mock(ConfigurableCodecForTest.class);
    when(mockCodec.createDecompressor()).thenReturn(mockDecoder);
    when(mockCodec.createInputStream(any(InputStream.class), any(Decompressor.class)))
        .thenReturn(mockCodecStream);
    // Valid TIF header so the failure happens in the codec read, not header parsing.
    byte[] header = new byte[] { (byte) 'T', (byte) 'I', (byte) 'F', (byte) 1};
    try {
      ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header),
          1024, 128, mockCodec, false, 0, mock(Logger.class), null);
      Assert.fail("shuffle was supposed to throw!");
    } catch (IOException e) {
      Assert.assertTrue(e.getCause() instanceof InternalError);
      Assert.assertTrue(e.getMessage().contains(codecErrorMsg));
    }
  }

  /**
   * Exception-translation matrix for shuffleToMemory: IllegalArgumentException and
   * InternalError are wrapped in an IOException (original as cause); a
   * SocketTimeoutException is rethrown as-is (still a SocketTimeoutException).
   */
  @Test
  public void testExceptionTranslation() throws Exception {
    String codecErrorMsg = "codec failure";
    CompressionInputStream mockCodecStream = mock(CompressionInputStream.class);
    when(mockCodecStream.read(any(byte[].class), anyInt(), anyInt()))
        .thenThrow(new IllegalArgumentException(codecErrorMsg));
    Decompressor mockDecoder = mock(Decompressor.class);
    CompressionCodec mockCodec = mock(ConfigurableCodecForTest.class);
    when(mockCodec.createDecompressor()).thenReturn(mockDecoder);
    when(mockCodec.createInputStream(any(InputStream.class), any(Decompressor.class)))
        .thenReturn(mockCodecStream);
    byte[] header = new byte[] { (byte) 'T', (byte) 'I', (byte) 'F', (byte) 1};
    try {
      ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header),
          1024, 128, mockCodec, false, 0, mock(Logger.class), null);
      Assert.fail("shuffle was supposed to throw!");
    } catch (IOException e) {
      // IllegalArgumentException is wrapped, not propagated raw.
      Assert.assertTrue(e.getCause() instanceof IllegalArgumentException);
      Assert.assertTrue(e.getMessage().contains(codecErrorMsg));
    }
    CompressionInputStream mockCodecStream1 = mock(CompressionInputStream.class);
    when(mockCodecStream1.read(any(byte[].class), anyInt(), anyInt()))
        .thenThrow(new SocketTimeoutException(codecErrorMsg));
    CompressionCodec mockCodec1 = mock(CompressionCodec.class);
    when(mockCodec1.createDecompressor()).thenReturn(mockDecoder);
    when(mockCodec1.createInputStream(any(InputStream.class), any(Decompressor.class)))
        .thenReturn(mockCodecStream1);
    try {
      ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header),
          1024, 128, mockCodec1, false, 0, mock(Logger.class), null);
      Assert.fail("shuffle was supposed to throw!");
    } catch (IOException e) {
      // Socket timeouts must surface unchanged so retry handling can see them.
      Assert.assertTrue(e instanceof SocketTimeoutException);
      Assert.assertTrue(e.getMessage().contains(codecErrorMsg));
    }
    CompressionInputStream mockCodecStream2 = mock(CompressionInputStream.class);
    when(mockCodecStream2.read(any(byte[].class), anyInt(), anyInt()))
        .thenThrow(new InternalError(codecErrorMsg));
    CompressionCodec mockCodec2 = mock(CompressionCodec.class);
    when(mockCodec2.createDecompressor()).thenReturn(mockDecoder);
    when(mockCodec2.createInputStream(any(InputStream.class), any(Decompressor.class)))
        .thenReturn(mockCodecStream2);
    try {
      ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header),
          1024, 128, mockCodec2, false, 0, mock(Logger.class), null);
      Assert.fail("shuffle was supposed to throw!");
    } catch (IOException e) {
      // Same wrapping as the Configurable-codec case, but with a plain CompressionCodec mock.
      Assert.assertTrue(e.getCause() instanceof InternalError);
      Assert.assertTrue(e.getMessage().contains(codecErrorMsg));
    }
  }

  @Test
  public void testShuffleToDiskChecksum() throws Exception {
    // verify sending a stream of zeroes without checksum validation
    // does not trigger an exception
    byte[] bogusData = new byte[1000];
    Arrays.fill(bogusData, (byte) 0);
    ByteArrayInputStream in = new ByteArrayInputStream(bogusData);
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    ShuffleUtils.shuffleToDisk(baos, "somehost", in,
        bogusData.length, 2000, mock(Logger.class), null, false, 0, false);
    Assert.assertArrayEquals(bogusData, baos.toByteArray());

    // verify sending same stream of zeroes with validation generates an exception
    in.reset();
    try {
      ShuffleUtils.shuffleToDisk(mock(OutputStream.class), "somehost", in,
          bogusData.length, 2000, mock(Logger.class), null, false, 0, true);
      Assert.fail("shuffle was supposed to throw!");
    } catch (IOException e) {
      // expected: all-zero payload cannot pass IFile checksum validation
    }
  }

  /**
   * With the active logger disabled only the aggregate logger fires (once for 1000 fetches);
   * with it enabled every individual fetch is logged while the aggregate count stays at 1.
   */
  @Test
  public void testFetchStatsLogger() throws Exception {
    Logger activeLogger = mock(Logger.class);
    Logger aggregateLogger = mock(Logger.class);
    FetchStatsLogger logger = new FetchStatsLogger(activeLogger, aggregateLogger);

    InputAttemptIdentifier ident = new InputAttemptIdentifier(1, 1);
    when(activeLogger.isInfoEnabled()).thenReturn(false);
    for (int i = 0; i < 1000; i++) {
      logger.logIndividualFetchComplete(10, 100, 1000, "testType", ident);
    }
    verify(activeLogger, times(0)).info(anyString());
    verify(aggregateLogger, times(1)).info(anyString(), Matchers.<Object[]>anyVararg());

    when(activeLogger.isInfoEnabled()).thenReturn(true);
    for (int i = 0; i < 1000; i++) {
      logger.logIndividualFetchComplete(10, 100, 1000, "testType", ident);
    }
    verify(activeLogger, times(1000)).info(anyString());
    // Aggregate logger count is unchanged from the first phase.
    verify(aggregateLogger, times(1)).info(anyString(), Matchers.<Object[]>anyVararg());
  }

  /**
   * A codec class which implements CompressionCodec, Configurable for testing purposes.
   */
  public static class ConfigurableCodecForTest implements CompressionCodec, Configurable {

    // All methods are stubs returning null: instances of this class are only used
    // as Mockito mock targets, so a codec type that is also Configurable exists.

    @Override
    public Compressor createCompressor() {
      return null;
    }

    @Override
    public Decompressor createDecompressor() {
      return null;
    }

    @Override
    public CompressionInputStream createInputStream(InputStream arg0) throws IOException {
      return null;
    }

    @Override
    public CompressionInputStream createInputStream(InputStream arg0, Decompressor arg1)
        throws IOException {
      return null;
    }

    @Override
    public CompressionOutputStream createOutputStream(OutputStream arg0) throws IOException {
      return null;
    }

    @Override
    public CompressionOutputStream createOutputStream(OutputStream arg0, Compressor arg1)
        throws IOException {
      return null;
    }

    @Override
    public Class<? extends Compressor> getCompressorType() {
      return null;
    }

    @Override
    public Class<? extends Decompressor> getDecompressorType() {
      return null;
    }

    @Override
    public String getDefaultExtension() {
      return null;
    }

    @Override
    public Configuration getConf() {
      return null;
    }

    @Override
    public void setConf(Configuration arg0) {
    }
  }
}