/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.prestosql.operator.exchange; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.prestosql.SequencePageBuilder; import io.prestosql.execution.Lifespan; import io.prestosql.operator.InterpretedHashGenerator; import io.prestosql.operator.PageAssertions; import io.prestosql.operator.PipelineExecutionStrategy; import io.prestosql.operator.exchange.LocalExchange.LocalExchangeFactory; import io.prestosql.operator.exchange.LocalExchange.LocalExchangeSinkFactory; import io.prestosql.operator.exchange.LocalExchange.LocalExchangeSinkFactoryId; import io.prestosql.spi.Page; import io.prestosql.spi.type.Type; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; import java.util.Optional; import java.util.function.Consumer; import static io.airlift.testing.Assertions.assertContains; import static io.prestosql.operator.PipelineExecutionStrategy.GROUPED_EXECUTION; import static io.prestosql.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION; import static io.prestosql.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; public class TestLocalExchange { private static final List<Type> TYPES = ImmutableList.of(BIGINT); private static final DataSize RETAINED_PAGE_SIZE = DataSize.ofBytes(createPage(42).getRetainedSizeInBytes()); private static final DataSize LOCAL_EXCHANGE_MAX_BUFFERED_BYTES = DataSize.of(32, DataSize.Unit.MEGABYTE); @DataProvider public static Object[][] executionStrategy() { return new Object[][] {{UNGROUPED_EXECUTION}, {GROUPED_EXECUTION}}; } @Test(dataProvider = "executionStrategy") public void testGatherSingleWriter(PipelineExecutionStrategy executionStrategy) { LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory( SINGLE_DISTRIBUTION, 8, TYPES, ImmutableList.of(), Optional.empty(), executionStrategy, DataSize.ofBytes(retainedSizeOfPages(99))); LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId(); localExchangeFactory.noMoreSinkFactories(); run(localExchangeFactory, executionStrategy, exchange -> { assertEquals(exchange.getBufferCount(), 1); assertExchangeTotalBufferedBytes(exchange, 0); LocalExchangeSinkFactory sinkFactory = exchange.getSinkFactory(localExchangeSinkFactoryId); LocalExchangeSource source = exchange.getSource(0); assertSource(source, 0); LocalExchangeSink sink = sinkFactory.createSink(); sinkFactory.close(); sinkFactory.noMoreSinkFactories(); assertSinkCanWrite(sink); assertSource(source, 0); // add the first page which should cause the reader to unblock ListenableFuture<?> readFuture = source.waitForReading(); assertFalse(readFuture.isDone()); sink.addPage(createPage(0)); assertTrue(readFuture.isDone()); assertExchangeTotalBufferedBytes(exchange, 1); assertSource(source, 1); sink.addPage(createPage(1)); assertSource(source, 2); assertExchangeTotalBufferedBytes(exchange, 2); assertRemovePage(source, createPage(0)); assertSource(source, 1); assertExchangeTotalBufferedBytes(exchange, 1); assertRemovePage(source, createPage(1)); assertSource(source, 0); assertExchangeTotalBufferedBytes(exchange, 0); sink.addPage(createPage(2)); sink.addPage(createPage(3)); assertSource(source, 2); assertExchangeTotalBufferedBytes(exchange, 2); sink.finish(); assertSinkFinished(sink); assertSource(source, 2); assertRemovePage(source, createPage(2)); assertSource(source, 1); assertSinkFinished(sink); assertExchangeTotalBufferedBytes(exchange, 1); assertRemovePage(source, createPage(3)); assertSourceFinished(source); assertExchangeTotalBufferedBytes(exchange, 0); }); } @Test(dataProvider = "executionStrategy") public void testBroadcast(PipelineExecutionStrategy executionStrategy) { LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory( FIXED_BROADCAST_DISTRIBUTION, 2, TYPES, ImmutableList.of(), Optional.empty(), executionStrategy, LOCAL_EXCHANGE_MAX_BUFFERED_BYTES); LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId(); localExchangeFactory.noMoreSinkFactories(); run(localExchangeFactory, executionStrategy, exchange -> { assertEquals(exchange.getBufferCount(), 2); assertExchangeTotalBufferedBytes(exchange, 0); LocalExchangeSinkFactory sinkFactory = exchange.getSinkFactory(localExchangeSinkFactoryId); LocalExchangeSink sinkA = sinkFactory.createSink(); assertSinkCanWrite(sinkA); LocalExchangeSink sinkB = sinkFactory.createSink(); assertSinkCanWrite(sinkB); sinkFactory.close(); sinkFactory.noMoreSinkFactories(); LocalExchangeSource sourceA = exchange.getSource(0); assertSource(sourceA, 0); LocalExchangeSource sourceB = exchange.getSource(1); assertSource(sourceB, 0); sinkA.addPage(createPage(0)); assertSource(sourceA, 1); assertSource(sourceB, 1); assertExchangeTotalBufferedBytes(exchange, 1); sinkA.addPage(createPage(0)); assertSource(sourceA, 2); assertSource(sourceB, 2); assertExchangeTotalBufferedBytes(exchange, 2); assertRemovePage(sourceA, createPage(0)); assertSource(sourceA, 1); assertSource(sourceB, 2); assertExchangeTotalBufferedBytes(exchange, 2); assertRemovePage(sourceA, createPage(0)); assertSource(sourceA, 0); assertSource(sourceB, 2); assertExchangeTotalBufferedBytes(exchange, 2); sinkA.finish(); assertSinkFinished(sinkA); assertExchangeTotalBufferedBytes(exchange, 2); sinkB.addPage(createPage(0)); assertSource(sourceA, 1); assertSource(sourceB, 3); assertExchangeTotalBufferedBytes(exchange, 3); sinkB.finish(); assertSinkFinished(sinkB); assertSource(sourceA, 1); assertSource(sourceB, 3); assertExchangeTotalBufferedBytes(exchange, 3); assertRemovePage(sourceA, createPage(0)); assertSourceFinished(sourceA); assertSource(sourceB, 3); assertExchangeTotalBufferedBytes(exchange, 3); assertRemovePage(sourceB, createPage(0)); assertRemovePage(sourceB, createPage(0)); assertSourceFinished(sourceA); assertSource(sourceB, 1); assertExchangeTotalBufferedBytes(exchange, 1); assertRemovePage(sourceB, createPage(0)); assertSourceFinished(sourceA); assertSourceFinished(sourceB); assertExchangeTotalBufferedBytes(exchange, 0); }); } @Test(dataProvider = "executionStrategy") public void testRandom(PipelineExecutionStrategy executionStrategy) { LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory( FIXED_ARBITRARY_DISTRIBUTION, 2, TYPES, ImmutableList.of(), Optional.empty(), executionStrategy, LOCAL_EXCHANGE_MAX_BUFFERED_BYTES); LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId(); localExchangeFactory.noMoreSinkFactories(); run(localExchangeFactory, executionStrategy, exchange -> { assertEquals(exchange.getBufferCount(), 2); assertExchangeTotalBufferedBytes(exchange, 0); LocalExchangeSinkFactory sinkFactory = exchange.getSinkFactory(localExchangeSinkFactoryId); LocalExchangeSink sink = sinkFactory.createSink(); assertSinkCanWrite(sink); sinkFactory.close(); sinkFactory.noMoreSinkFactories(); LocalExchangeSource sourceA = exchange.getSource(0); assertSource(sourceA, 0); LocalExchangeSource sourceB = exchange.getSource(1); assertSource(sourceB, 0); for (int i = 0; i < 100; i++) { Page page = createPage(0); sink.addPage(page); assertExchangeTotalBufferedBytes(exchange, i + 1); LocalExchangeBufferInfo bufferInfoA = sourceA.getBufferInfo(); LocalExchangeBufferInfo bufferInfoB = sourceB.getBufferInfo(); assertEquals(bufferInfoA.getBufferedBytes() + bufferInfoB.getBufferedBytes(), retainedSizeOfPages(i + 1)); assertEquals(bufferInfoA.getBufferedPages() + bufferInfoB.getBufferedPages(), i + 1); } // we should get ~50 pages per source, but we should get at least some pages in each buffer assertTrue(sourceA.getBufferInfo().getBufferedPages() > 0); assertTrue(sourceB.getBufferInfo().getBufferedPages() > 0); assertExchangeTotalBufferedBytes(exchange, 100); }); } @Test(dataProvider = "executionStrategy") public void testPassthrough(PipelineExecutionStrategy executionStrategy) { LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory( FIXED_PASSTHROUGH_DISTRIBUTION, 2, TYPES, ImmutableList.of(), Optional.empty(), executionStrategy, DataSize.ofBytes(retainedSizeOfPages(1))); LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId(); localExchangeFactory.noMoreSinkFactories(); run(localExchangeFactory, executionStrategy, exchange -> { assertEquals(exchange.getBufferCount(), 2); assertExchangeTotalBufferedBytes(exchange, 0); LocalExchangeSinkFactory sinkFactory = exchange.getSinkFactory(localExchangeSinkFactoryId); LocalExchangeSink sinkA = sinkFactory.createSink(); LocalExchangeSink sinkB = sinkFactory.createSink(); assertSinkCanWrite(sinkA); assertSinkCanWrite(sinkB); sinkFactory.close(); sinkFactory.noMoreSinkFactories(); LocalExchangeSource sourceA = exchange.getSource(0); assertSource(sourceA, 0); LocalExchangeSource sourceB = exchange.getSource(1); assertSource(sourceB, 0); sinkA.addPage(createPage(0)); assertSource(sourceA, 1); assertSource(sourceB, 0); assertSinkWriteBlocked(sinkA); assertSinkCanWrite(sinkB); sinkB.addPage(createPage(1)); assertSource(sourceA, 1); assertSource(sourceB, 1); assertSinkWriteBlocked(sinkA); assertExchangeTotalBufferedBytes(exchange, 2); assertRemovePage(sourceA, createPage(0)); assertSource(sourceA, 0); assertSinkCanWrite(sinkA); assertSinkWriteBlocked(sinkB); assertExchangeTotalBufferedBytes(exchange, 1); sinkA.finish(); assertSinkFinished(sinkA); assertSource(sourceB, 1); sourceA.finish(); sourceB.finish(); assertRemovePage(sourceB, createPage(1)); assertSourceFinished(sourceA); assertSourceFinished(sourceB); assertSinkFinished(sinkB); assertExchangeTotalBufferedBytes(exchange, 0); }); } @Test(dataProvider = "executionStrategy") public void testPartition(PipelineExecutionStrategy executionStrategy) { LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory( FIXED_HASH_DISTRIBUTION, 2, TYPES, ImmutableList.of(0), Optional.empty(), executionStrategy, LOCAL_EXCHANGE_MAX_BUFFERED_BYTES); LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId(); localExchangeFactory.noMoreSinkFactories(); run(localExchangeFactory, executionStrategy, exchange -> { assertEquals(exchange.getBufferCount(), 2); assertExchangeTotalBufferedBytes(exchange, 0); LocalExchangeSinkFactory sinkFactory = exchange.getSinkFactory(localExchangeSinkFactoryId); LocalExchangeSink sink = sinkFactory.createSink(); assertSinkCanWrite(sink); sinkFactory.close(); sinkFactory.noMoreSinkFactories(); LocalExchangeSource sourceA = exchange.getSource(0); assertSource(sourceA, 0); LocalExchangeSource sourceB = exchange.getSource(1); assertSource(sourceB, 0); sink.addPage(createPage(0)); assertSource(sourceA, 1); assertSource(sourceB, 1); assertTrue(exchange.getBufferedBytes() >= retainedSizeOfPages(1)); sink.addPage(createPage(0)); assertSource(sourceA, 2); assertSource(sourceB, 2); assertTrue(exchange.getBufferedBytes() >= retainedSizeOfPages(2)); assertPartitionedRemovePage(sourceA, 0, 2); assertSource(sourceA, 1); assertSource(sourceB, 2); assertPartitionedRemovePage(sourceA, 0, 2); assertSource(sourceA, 0); assertSource(sourceB, 2); sink.finish(); assertSinkFinished(sink); assertSourceFinished(sourceA); assertSource(sourceB, 2); assertPartitionedRemovePage(sourceB, 1, 2); assertSourceFinished(sourceA); assertSource(sourceB, 1); assertPartitionedRemovePage(sourceB, 1, 2); assertSourceFinished(sourceA); assertSourceFinished(sourceB); assertExchangeTotalBufferedBytes(exchange, 0); }); } @Test(dataProvider = "executionStrategy") public void writeUnblockWhenAllReadersFinish(PipelineExecutionStrategy executionStrategy) { ImmutableList<Type> types = ImmutableList.of(BIGINT); LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory( FIXED_BROADCAST_DISTRIBUTION, 2, types, ImmutableList.of(), Optional.empty(), executionStrategy, LOCAL_EXCHANGE_MAX_BUFFERED_BYTES); LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId(); localExchangeFactory.noMoreSinkFactories(); run(localExchangeFactory, executionStrategy, exchange -> { assertEquals(exchange.getBufferCount(), 2); assertExchangeTotalBufferedBytes(exchange, 0); LocalExchangeSinkFactory sinkFactory = exchange.getSinkFactory(localExchangeSinkFactoryId); LocalExchangeSink sinkA = sinkFactory.createSink(); assertSinkCanWrite(sinkA); LocalExchangeSink sinkB = sinkFactory.createSink(); assertSinkCanWrite(sinkB); sinkFactory.close(); sinkFactory.noMoreSinkFactories(); LocalExchangeSource sourceA = exchange.getSource(0); assertSource(sourceA, 0); LocalExchangeSource sourceB = exchange.getSource(1); assertSource(sourceB, 0); sourceA.finish(); assertSourceFinished(sourceA); assertSinkCanWrite(sinkA); assertSinkCanWrite(sinkB); sourceB.finish(); assertSourceFinished(sourceB); assertSinkFinished(sinkA); assertSinkFinished(sinkB); }); } @Test(dataProvider = "executionStrategy") public void writeUnblockWhenAllReadersFinishAndPagesConsumed(PipelineExecutionStrategy executionStrategy) { LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory( FIXED_BROADCAST_DISTRIBUTION, 2, TYPES, ImmutableList.of(), Optional.empty(), executionStrategy, DataSize.ofBytes(1)); LocalExchangeSinkFactoryId localExchangeSinkFactoryId = localExchangeFactory.newSinkFactoryId(); localExchangeFactory.noMoreSinkFactories(); run(localExchangeFactory, executionStrategy, exchange -> { assertEquals(exchange.getBufferCount(), 2); assertExchangeTotalBufferedBytes(exchange, 0); LocalExchangeSinkFactory sinkFactory = exchange.getSinkFactory(localExchangeSinkFactoryId); LocalExchangeSink sinkA = sinkFactory.createSink(); assertSinkCanWrite(sinkA); LocalExchangeSink sinkB = sinkFactory.createSink(); assertSinkCanWrite(sinkB); sinkFactory.close(); sinkFactory.noMoreSinkFactories(); LocalExchangeSource sourceA = exchange.getSource(0); assertSource(sourceA, 0); LocalExchangeSource sourceB = exchange.getSource(1); assertSource(sourceB, 0); sinkA.addPage(createPage(0)); ListenableFuture<?> sinkAFuture = assertSinkWriteBlocked(sinkA); ListenableFuture<?> sinkBFuture = assertSinkWriteBlocked(sinkB); assertSource(sourceA, 1); assertSource(sourceB, 1); assertExchangeTotalBufferedBytes(exchange, 1); sourceA.finish(); assertSource(sourceA, 1); assertRemovePage(sourceA, createPage(0)); assertSourceFinished(sourceA); assertExchangeTotalBufferedBytes(exchange, 1); assertSource(sourceB, 1); assertSinkWriteBlocked(sinkA); assertSinkWriteBlocked(sinkB); sourceB.finish(); assertSource(sourceB, 1); assertRemovePage(sourceB, createPage(0)); assertSourceFinished(sourceB); assertExchangeTotalBufferedBytes(exchange, 0); assertTrue(sinkAFuture.isDone()); assertTrue(sinkBFuture.isDone()); assertSinkFinished(sinkA); assertSinkFinished(sinkB); }); } @Test public void testMismatchedExecutionStrategy() { // If sink/source didn't create a matching set of exchanges, operators will block forever, // waiting for the other half that will never show up. // The most common reason of mismatch is when one of sink/source created the wrong kind of local exchange. // In such case, we want to fail loudly. LocalExchangeFactory ungroupedLocalExchangeFactory = new LocalExchangeFactory( FIXED_HASH_DISTRIBUTION, 2, TYPES, ImmutableList.of(0), Optional.empty(), UNGROUPED_EXECUTION, LOCAL_EXCHANGE_MAX_BUFFERED_BYTES); try { ungroupedLocalExchangeFactory.getLocalExchange(Lifespan.driverGroup(3)); fail("expected failure"); } catch (IllegalArgumentException e) { assertContains(e.getMessage(), "Driver-group exchange cannot be created."); } LocalExchangeFactory groupedLocalExchangeFactory = new LocalExchangeFactory( FIXED_HASH_DISTRIBUTION, 2, TYPES, ImmutableList.of(0), Optional.empty(), GROUPED_EXECUTION, LOCAL_EXCHANGE_MAX_BUFFERED_BYTES); try { groupedLocalExchangeFactory.getLocalExchange(Lifespan.taskWide()); fail("expected failure"); } catch (IllegalArgumentException e) { assertContains(e.getMessage(), "Task-wide exchange cannot be created."); } } private void run(LocalExchangeFactory localExchangeFactory, PipelineExecutionStrategy pipelineExecutionStrategy, Consumer<LocalExchange> test) { switch (pipelineExecutionStrategy) { case UNGROUPED_EXECUTION: test.accept(localExchangeFactory.getLocalExchange(Lifespan.taskWide())); break; case GROUPED_EXECUTION: test.accept(localExchangeFactory.getLocalExchange(Lifespan.driverGroup(1))); test.accept(localExchangeFactory.getLocalExchange(Lifespan.driverGroup(12))); test.accept(localExchangeFactory.getLocalExchange(Lifespan.driverGroup(23))); break; default: throw new IllegalArgumentException("Unknown pipelineExecutionStrategy"); } } private static void assertSource(LocalExchangeSource source, int pageCount) { LocalExchangeBufferInfo bufferInfo = source.getBufferInfo(); assertEquals(bufferInfo.getBufferedPages(), pageCount); assertFalse(source.isFinished()); if (pageCount == 0) { assertFalse(source.waitForReading().isDone()); assertNull(source.removePage()); assertFalse(source.waitForReading().isDone()); assertFalse(source.isFinished()); assertEquals(bufferInfo.getBufferedBytes(), 0); } else { assertTrue(source.waitForReading().isDone()); assertTrue(bufferInfo.getBufferedBytes() > 0); } } private static void assertSourceFinished(LocalExchangeSource source) { assertTrue(source.isFinished()); LocalExchangeBufferInfo bufferInfo = source.getBufferInfo(); assertEquals(bufferInfo.getBufferedPages(), 0); assertEquals(bufferInfo.getBufferedBytes(), 0); assertTrue(source.waitForReading().isDone()); assertNull(source.removePage()); assertTrue(source.waitForReading().isDone()); assertTrue(source.isFinished()); } private static void assertRemovePage(LocalExchangeSource source, Page expectedPage) { assertTrue(source.waitForReading().isDone()); Page actualPage = source.removePage(); assertNotNull(actualPage); assertEquals(actualPage.getChannelCount(), expectedPage.getChannelCount()); PageAssertions.assertPageEquals(TYPES, actualPage, expectedPage); } private static void assertPartitionedRemovePage(LocalExchangeSource source, int partition, int partitionCount) { assertTrue(source.waitForReading().isDone()); Page page = source.removePage(); assertNotNull(page); LocalPartitionGenerator partitionGenerator = new LocalPartitionGenerator(new InterpretedHashGenerator(TYPES, new int[] {0}), partitionCount); for (int position = 0; position < page.getPositionCount(); position++) { assertEquals(partitionGenerator.getPartition(page, position), partition); } } private static void assertSinkCanWrite(LocalExchangeSink sink) { assertFalse(sink.isFinished()); assertTrue(sink.waitForWriting().isDone()); } private static ListenableFuture<?> assertSinkWriteBlocked(LocalExchangeSink sink) { assertFalse(sink.isFinished()); ListenableFuture<?> writeFuture = sink.waitForWriting(); assertFalse(writeFuture.isDone()); return writeFuture; } private static void assertSinkFinished(LocalExchangeSink sink) { assertTrue(sink.isFinished()); assertTrue(sink.waitForWriting().isDone()); // this will be ignored sink.addPage(createPage(0)); assertTrue(sink.isFinished()); assertTrue(sink.waitForWriting().isDone()); } private static void assertExchangeTotalBufferedBytes(LocalExchange exchange, int pageCount) { assertEquals(exchange.getBufferedBytes(), retainedSizeOfPages(pageCount)); } private static Page createPage(int i) { return SequencePageBuilder.createSequencePage(TYPES, 100, i); } public static long retainedSizeOfPages(int count) { return RETAINED_PAGE_SIZE.toBytes() * count; } }