package org.apache.tez.runtime.library.common; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.LinkedListMultimap; import com.google.common.collect.ListMultimap; import com.google.common.collect.Lists; import java.nio.ByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.LocalDirAllocator; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.BoundedByteArrayOutputStream; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.RawComparator; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.WritableComparator; import org.apache.hadoop.io.serializer.SerializationFactory; import org.apache.hadoop.io.serializer.Serializer; import org.apache.hadoop.util.Progress; import org.apache.hadoop.util.Progressable; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.common.counters.GenericCounter; import org.apache.tez.common.counters.TezCounter; import org.apache.tez.common.counters.TezCounters; import org.apache.tez.dag.api.UserPayload; import org.apache.tez.runtime.api.InputContext; import org.apache.tez.runtime.library.api.TezRuntimeConfiguration; import org.apache.tez.runtime.library.common.comparator.TezBytesComparator; import org.apache.tez.runtime.library.common.serializer.TezBytesWritableSerialization; import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.InMemoryReader; import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.InMemoryWriter; import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.MergeManager; import org.apache.tez.runtime.library.common.sort.impl.IFile; import org.apache.tez.runtime.library.common.sort.impl.TezMerger; import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.mockito.internal.util.collections.Sets; import java.io.IOException; import java.math.BigInteger; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.TreeMap; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; /** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * <p/> * http://www.apache.org/licenses/LICENSE-2.0 * <p/> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ @RunWith(Parameterized.class) public class TestValuesIterator { private static final Logger LOG = LoggerFactory.getLogger(TestValuesIterator.class); static final String TEZ_BYTES_SERIALIZATION = TezBytesWritableSerialization.class.getName(); enum TestWithComparator { LONG, INT, BYTES, TEZ_BYTES, TEXT, CUSTOM } Configuration conf; FileSystem fs; static final Random rnd = new Random(); final Class keyClass; final Class valClass; final RawComparator comparator; final RawComparator correctComparator; final boolean expectedTestResult; int mergeFactor; //For storing original data final ListMultimap<Writable, Writable> originalData; TezRawKeyValueIterator rawKeyValueIterator; Path baseDir; Path tmpDir; Path[] streamPaths; //merge stream paths /** * Constructor * * @param serializationClassName serialization class to be used * @param key key class name * @param val value class name * @param comparator to be used * @param correctComparator (real comparator to be used for correct results) * @param testResult expected result * @throws IOException */ public TestValuesIterator(String serializationClassName, Class key, Class val, TestWithComparator comparator, TestWithComparator correctComparator, boolean testResult) throws IOException { this.keyClass = key; this.valClass = val; this.comparator = getComparator(comparator); this.correctComparator = (correctComparator == null) ? this.comparator : getComparator(correctComparator); this.expectedTestResult = testResult; originalData = LinkedListMultimap.create(); setupConf(serializationClassName); } private void setupConf(String serializationClassName) throws IOException { mergeFactor = 2; conf = new Configuration(); conf.setInt(TezRuntimeConfiguration.TEZ_RUNTIME_IO_SORT_FACTOR, mergeFactor); if (serializationClassName != null) { conf.set(CommonConfigurationKeys.IO_SERIALIZATIONS_KEY, serializationClassName + "," + conf.get(CommonConfigurationKeys.IO_SERIALIZATIONS_KEY)); } baseDir = new Path(".", this.getClass().getName()); String localDirs = baseDir.toString(); conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, localDirs); fs = FileSystem.getLocal(conf); } @Before public void setup() throws Exception { fs.mkdirs(baseDir); tmpDir = new Path(baseDir, "tmp"); } @After public void cleanup() throws Exception { fs.delete(baseDir, true); originalData.clear(); } @Test(timeout = 20000) public void testIteratorWithInMemoryReader() throws IOException, InterruptedException { ValuesIterator iterator = createIterator(true); verifyIteratorData(iterator); } @Test(timeout = 20000) public void testIteratorWithIFileReader() throws IOException, InterruptedException { ValuesIterator iterator = createIterator(false); verifyIteratorData(iterator); } @Test(timeout = 20000) public void testCountedIteratorWithInmemoryReader() throws IOException, InterruptedException { verifyCountedIteratorReader(true); } @Test(timeout = 20000) public void testCountedIteratorWithIFileReader() throws IOException, InterruptedException { verifyCountedIteratorReader(false); } private void verifyCountedIteratorReader(boolean inMemory) throws IOException, InterruptedException { TezCounter keyCounter = new GenericCounter("inputKeyCounter", "y3"); TezCounter tupleCounter = new GenericCounter("inputValuesCounter", "y4"); ValuesIterator iterator = createCountedIterator(inMemory, keyCounter, tupleCounter); List<Integer> sequence = verifyIteratorData(iterator); if (expectedTestResult) { assertEquals((long) sequence.size(), keyCounter.getValue()); long rows = 0; for (Integer i : sequence) { rows += i.longValue(); } assertEquals(rows, tupleCounter.getValue()); } } @Test(timeout = 20000) public void testIteratorWithIFileReaderEmptyPartitions() throws IOException, InterruptedException { ValuesIterator iterator = createEmptyIterator(false); assertTrue(iterator.moveToNext() == false); iterator = createEmptyIterator(true); assertTrue(iterator.moveToNext() == false); } private void getNextFromFinishedIterator(ValuesIterator iterator) { try { boolean hasNext = iterator.moveToNext(); fail(); } catch(IOException e) { assertTrue(e.getMessage().contains("Please check if you are invoking moveToNext()")); } } @SuppressWarnings("unchecked") private ValuesIterator createEmptyIterator(boolean inMemory) throws IOException, InterruptedException { if (!inMemory) { streamPaths = new Path[0]; //This will return EmptyIterator rawKeyValueIterator = TezMerger.merge(conf, fs, keyClass, valClass, null, false, -1, 1024, streamPaths, false, mergeFactor, tmpDir, comparator, new ProgressReporter(), null, null, null, null); } else { List<TezMerger.Segment> segments = Lists.newLinkedList(); //This will return EmptyIterator rawKeyValueIterator = TezMerger.merge(conf, fs, keyClass, valClass, segments, mergeFactor, tmpDir, comparator, new ProgressReporter(), new GenericCounter("readsCounter", "y"), new GenericCounter("writesCounter", "y1"), new GenericCounter("bytesReadCounter", "y2"), new Progress()); } return new ValuesIterator(rawKeyValueIterator, comparator, keyClass, valClass, conf, (TezCounter) new GenericCounter("inputKeyCounter", "y3"), (TezCounter) new GenericCounter("inputValueCounter", "y4")); } /** * Tests whether data in valuesIterator matches with sorted input data set. * * Returns a list of value counts for each key. * * @param valuesIterator * @return List * @throws IOException */ @SuppressWarnings("unchecked") private List<Integer> verifyIteratorData( ValuesIterator valuesIterator) throws IOException { boolean result = true; ArrayList<Integer> sequence = new ArrayList<Integer>(); //sort original data based on comparator ListMultimap<Writable, Writable> sortedMap = new ImmutableListMultimap.Builder<Writable, Writable>() .orderKeysBy(this.correctComparator).putAll (originalData).build(); Set<Map.Entry<Writable, Writable>> oriKeySet = Sets.newSet(); oriKeySet.addAll(sortedMap.entries()); //Iterate through sorted data and valuesIterator for verification for (Map.Entry<Writable, Writable> entry : oriKeySet) { assertTrue(valuesIterator.moveToNext()); Writable oriKey = entry.getKey(); //Verify if the key and the original key are same if (!oriKey.equals((Writable) valuesIterator.getKey())) { result = false; break; } int valueCount = 0; //Verify values Iterator<Writable> vItr = valuesIterator.getValues().iterator(); for (Writable val : sortedMap.get(oriKey)) { assertTrue(vItr.hasNext()); //Verify if the values are same if (!val.equals((Writable) vItr.next())) { result = false; break; } valueCount++; } sequence.add(valueCount); assertTrue("At least 1 value per key", valueCount > 0); } if (expectedTestResult) { assertTrue(result); assertFalse(valuesIterator.moveToNext()); getNextFromFinishedIterator(valuesIterator); } else { while(valuesIterator.moveToNext()) { //iterate through all keys } getNextFromFinishedIterator(valuesIterator); assertFalse(result); } return sequence; } /** * Create sample data (in memory / disk based), merge them and return ValuesIterator * * @param inMemory * @return ValuesIterator * @throws IOException */ @SuppressWarnings("unchecked") private ValuesIterator createIterator(boolean inMemory) throws IOException, InterruptedException { if (!inMemory) { streamPaths = createFiles(); //Merge all files to get KeyValueIterator rawKeyValueIterator = TezMerger.merge(conf, fs, keyClass, valClass, null, false, -1, 1024, streamPaths, false, mergeFactor, tmpDir, comparator, new ProgressReporter(), null, null, null, null); } else { List<TezMerger.Segment> segments = createInMemStreams(); rawKeyValueIterator = TezMerger.merge(conf, fs, keyClass, valClass, segments, mergeFactor, tmpDir, comparator, new ProgressReporter(), new GenericCounter("readsCounter", "y"), new GenericCounter("writesCounter", "y1"), new GenericCounter("bytesReadCounter", "y2"), new Progress()); } return new ValuesIterator(rawKeyValueIterator, comparator, keyClass, valClass, conf, (TezCounter) new GenericCounter("inputKeyCounter", "y3"), (TezCounter) new GenericCounter("inputValueCounter", "y4")); } /** * Create sample data (in memory), with an attached counter and return ValuesIterator * * @param inMemory * @param keyCounter * @param tupleCounter * @return ValuesIterator * @throws IOException */ @SuppressWarnings("unchecked") private ValuesIterator createCountedIterator(boolean inMemory, TezCounter keyCounter, TezCounter tupleCounter) throws IOException, InterruptedException { if (!inMemory) { streamPaths = createFiles(); //Merge all files to get KeyValueIterator rawKeyValueIterator = TezMerger.merge(conf, fs, keyClass, valClass, null, false, -1, 1024, streamPaths, false, mergeFactor, tmpDir, comparator, new ProgressReporter(), null, null, null, null); } else { List<TezMerger.Segment> segments = createInMemStreams(); rawKeyValueIterator = TezMerger.merge(conf, fs, keyClass, valClass, segments, mergeFactor, tmpDir, comparator, new ProgressReporter(), new GenericCounter("readsCounter", "y"), new GenericCounter("writesCounter", "y1"), new GenericCounter("bytesReadCounter", "y2"), new Progress()); } return new ValuesIterator(rawKeyValueIterator, comparator, keyClass, valClass, conf, keyCounter, tupleCounter); } @Parameterized.Parameters(name = "test[{0}, {1}, {2}, {3} {4} {5} {6}]") public static Collection<Object[]> getParameters() { Collection<Object[]> parameters = new ArrayList<Object[]>(); //parameters for constructor parameters.add(new Object[] { null, Text.class, Text.class, TestWithComparator.TEXT, null, true }); parameters.add(new Object[] { null, LongWritable.class, Text.class, TestWithComparator.LONG, null, true }); parameters.add(new Object[] { null, IntWritable.class, Text.class, TestWithComparator.INT, null, true }); parameters.add(new Object[] { null, BytesWritable.class, BytesWritable.class, TestWithComparator.BYTES, null, true }); parameters.add(new Object[] { TEZ_BYTES_SERIALIZATION, BytesWritable.class, BytesWritable.class, TestWithComparator.TEZ_BYTES, null, true }); parameters.add(new Object[] { TEZ_BYTES_SERIALIZATION, BytesWritable.class, LongWritable.class, TestWithComparator.TEZ_BYTES, null, true }); parameters.add(new Object[] { TEZ_BYTES_SERIALIZATION, CustomKey.class, LongWritable.class, TestWithComparator.TEZ_BYTES, null, true }); //negative tests parameters.add(new Object[] { TEZ_BYTES_SERIALIZATION, BytesWritable.class, BytesWritable.class, TestWithComparator.BYTES, TestWithComparator.TEZ_BYTES, false }); parameters.add(new Object[] { TEZ_BYTES_SERIALIZATION, CustomKey.class, LongWritable.class, TestWithComparator.CUSTOM, TestWithComparator.TEZ_BYTES, false }); return parameters; } private RawComparator getComparator(TestWithComparator comparator) { switch (comparator) { case LONG: return new LongWritable.Comparator(); case INT: return new IntWritable.Comparator(); case BYTES: return new BytesWritable.Comparator(); case TEZ_BYTES: return new TezBytesComparator(); case TEXT: return new Text.Comparator(); case CUSTOM: return new CustomKey.Comparator(); default: return null; } } private Path[] createFiles() throws IOException { int numberOfStreams = Math.max(2, rnd.nextInt(10)); mergeFactor = Math.max(mergeFactor, numberOfStreams); LOG.info("No of streams : " + numberOfStreams); Path[] paths = new Path[numberOfStreams]; for (int i = 0; i < numberOfStreams; i++) { paths[i] = new Path(baseDir, "ifile_" + i + ".out"); FSDataOutputStream out = fs.create(paths[i]); //write data with RLE IFile.Writer writer = new IFile.Writer(conf, out, keyClass, valClass, null, null, null, true); Map<Writable, Writable> data = createData(); for (Map.Entry<Writable, Writable> entry : data.entrySet()) { writer.append(entry.getKey(), entry.getValue()); originalData.put(entry.getKey(), entry.getValue()); if (rnd.nextInt() % 2 == 0) { for (int j = 0; j < rnd.nextInt(100); j++) { //add some duplicate keys writer.append(entry.getKey(), entry.getValue()); originalData.put(entry.getKey(), entry.getValue()); } } } LOG.info("Wrote " + data.size() + " in " + paths[i]); data.clear(); writer.close(); out.close(); } return paths; } /** * create inmemory segments * * @return * @throws IOException */ @SuppressWarnings("unchecked") public List<TezMerger.Segment> createInMemStreams() throws IOException { int numberOfStreams = Math.max(2, rnd.nextInt(10)); LOG.info("No of streams : " + numberOfStreams); SerializationFactory serializationFactory = new SerializationFactory(conf); Serializer keySerializer = serializationFactory.getSerializer(keyClass); Serializer valueSerializer = serializationFactory.getSerializer(valClass); LocalDirAllocator localDirAllocator = new LocalDirAllocator(TezRuntimeFrameworkConfigs.LOCAL_DIRS); InputContext context = createTezInputContext(); MergeManager mergeManager = new MergeManager(conf, fs, localDirAllocator, context, null, null, null, null, null, 1024 * 1024 * 10, null, false, -1); DataOutputBuffer keyBuf = new DataOutputBuffer(); DataOutputBuffer valBuf = new DataOutputBuffer(); DataInputBuffer keyIn = new DataInputBuffer(); DataInputBuffer valIn = new DataInputBuffer(); keySerializer.open(keyBuf); valueSerializer.open(valBuf); List<TezMerger.Segment> segments = new LinkedList<TezMerger.Segment>(); for (int i = 0; i < numberOfStreams; i++) { BoundedByteArrayOutputStream bout = new BoundedByteArrayOutputStream(1024 * 1024); InMemoryWriter writer = new InMemoryWriter(bout); Map<Writable, Writable> data = createData(); //write data for (Map.Entry<Writable, Writable> entry : data.entrySet()) { keySerializer.serialize(entry.getKey()); valueSerializer.serialize(entry.getValue()); keyIn.reset(keyBuf.getData(), 0, keyBuf.getLength()); valIn.reset(valBuf.getData(), 0, valBuf.getLength()); writer.append(keyIn, valIn); originalData.put(entry.getKey(), entry.getValue()); keyBuf.reset(); valBuf.reset(); keyIn.reset(); valIn.reset(); } IFile.Reader reader = new InMemoryReader(mergeManager, null, bout.getBuffer(), 0, bout.getBuffer().length); segments.add(new TezMerger.Segment(reader, null)); data.clear(); writer.close(); } return segments; } private InputContext createTezInputContext() { TezCounters counters = new TezCounters(); InputContext inputContext = mock(InputContext.class); doReturn(1024 * 1024 * 100l).when(inputContext).getTotalMemoryAvailableToTask(); doReturn(counters).when(inputContext).getCounters(); doReturn(1).when(inputContext).getInputIndex(); doReturn("srcVertex").when(inputContext).getSourceVertexName(); doReturn(1).when(inputContext).getTaskVertexIndex(); doReturn(UserPayload.create(ByteBuffer.wrap(new byte[1024]))).when(inputContext).getUserPayload(); return inputContext; } @SuppressWarnings("unchecked") private Map<Writable, Writable> createData() { Map<Writable, Writable> map = new TreeMap<Writable, Writable>(comparator); for (int j = 0; j < Math.max(10, rnd.nextInt(50)); j++) { Writable key = createData(keyClass); Writable value = createData(valClass); map.put(key, value); //sortedDataMap.put(key, value); } return map; } private Writable createData(Class c) { if (c.getName().equalsIgnoreCase(BytesWritable.class.getName())) { return new BytesWritable(new BigInteger(256, rnd).toString().getBytes()); } else if (c.getName().equalsIgnoreCase(IntWritable.class.getName())) { return new IntWritable(rnd.nextInt()); } else if (c.getName().equalsIgnoreCase(LongWritable.class.getName())) { return new LongWritable(rnd.nextLong()); } else if (c.getName().equalsIgnoreCase(CustomKey.class.getName())) { String rndStr = new BigInteger(256, rnd).toString() + "_" + new BigInteger(256, rnd).toString(); return new CustomKey(rndStr.getBytes(), rndStr.hashCode()); } else if (c.getName().equalsIgnoreCase(Text.class.getName())) { String rndStr = new BigInteger(256, rnd).toString() + "_" + new BigInteger(256, rnd).toString(); return new Text(rndStr); } else { throw new IllegalArgumentException("Illegal argument : " + c.getName()); } } private static class ProgressReporter implements Progressable { @Override public void progress() { //no impl } } //Custom key and comparator public static class CustomKey extends BytesWritable { private static final int LENGTH_BYTES = 4; private int hashCode; public CustomKey() { } public CustomKey(byte[] data, int hashCode) { super(data); this.hashCode = hashCode; } @Override public int hashCode() { return hashCode; } public static class Comparator extends WritableComparator { public Comparator() { super(CustomKey.class); } /** * Compare the buffers in serialized form. */ @Override public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) { return compareBytes(b1, s1 + LENGTH_BYTES, l1 - LENGTH_BYTES, b2, s2 + LENGTH_BYTES, l2 - LENGTH_BYTES); } } static { WritableComparator.define(CustomKey.class, new Comparator()); } } }