/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.io.network.api.writer;

import org.apache.flink.api.common.JobID;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer.SerializationResult;
import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.BufferBuilderAndConsumerTest;
import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.BufferProvider;
import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.MockResultPartitionWriter;
import org.apache.flink.runtime.io.network.partition.NoOpBufferAvailablityListener;
import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier;
import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
import org.apache.flink.runtime.io.network.partition.PipelinedSubpartitionView;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionTest;
import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.util.DeserializationUtils;
import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
import org.apache.flink.runtime.operators.shipping.OutputEmitter;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.taskmanager.ConsumableNotifyingResultPartitionWriterDecorator;
import org.apache.flink.runtime.taskmanager.NoOpTaskActions;
import org.apache.flink.testutils.serialization.types.SerializationTestType;
import org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory;
import org.apache.flink.testutils.serialization.types.Util;
import org.apache.flink.types.IntValue;
import org.apache.flink.util.XORShiftRandom;

import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;

import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.buildSingleBuffer;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.spy;

/**
 * Tests for the {@link RecordWriter}.
 *
 * <p>By default the tests run against a record writer built with the default
 * {@link RecordWriterBuilder} settings. The package-private constructor allows running the same
 * tests against a record writer that uses a {@link ShipStrategyType#BROADCAST} channel selector.
 */
public class RecordWriterTest {

	/** Whether {@code createRecordWriter} configures a broadcast channel selector. */
	private final boolean isBroadcastWriter;

	public RecordWriterTest() {
		this(false);
	}

	RecordWriterTest(boolean isBroadcastWriter) {
		this.isBroadcastWriter = isBroadcastWriter;
	}

	@Rule
	public TemporaryFolder tempFolder = new TemporaryFolder();

	// ---------------------------------------------------------------------------------------------
	// Resource release tests
	// ---------------------------------------------------------------------------------------------

	/**
	 * Tests a fix for FLINK-2089.
	 *
	 * @see <a href="https://issues.apache.org/jira/browse/FLINK-2089">FLINK-2089</a>
	 */
	@Test
	public void testClearBuffersAfterInterruptDuringBlockingBufferRequest() throws Exception {
		ExecutorService executor = null;

		try {
			executor = Executors.newSingleThreadExecutor();

			TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(1);

			KeepingPartitionWriter partitionWriter = new KeepingPartitionWriter(bufferProvider);

			final RecordWriter<IntValue> recordWriter = createRecordWriter(partitionWriter);

			CountDownLatch waitLock = new CountDownLatch(1);
			Future<?> result = executor.submit(new Callable<Void>() {
				@Override
				public Void call() throws Exception {
					IntValue val = new IntValue(0);

					try {
						recordWriter.emit(val);
						recordWriter.flushAll();
						waitLock.countDown();
						recordWriter.emit(val);
					}
					catch (InterruptedException e) {
						recordWriter.clearBuffers();
					}

					return null;
				}
			});

			waitLock.await();

			// Interrupt the Thread.
			//
			// The second emit call requests a new buffer and blocks the thread.
			// When interrupting the thread at this point, clearing the buffers
			// should not recycle any buffer.
			result.cancel(true);

			recordWriter.clearBuffers();

			// The buffer written out by the first emit is still held by the partition writer,
			// so clearBuffers() must not have recycled it: no buffer is available.
			assertEquals(0, bufferProvider.getNumberOfAvailableBuffers());

			// Closing the partition writer releases the kept buffer back to the provider.
			partitionWriter.close();
			assertEquals(1, bufferProvider.getNumberOfAvailableBuffers());
		}
		finally {
			if (executor != null) {
				executor.shutdown();
			}
		}
	}

	/**
	 * Tests that {@link RecordWriter#clearBuffers()} also clears the serializer state, so that a
	 * subsequent {@link RecordWriter#flushAll()} does not fail.
	 */
	@Test
	public void testSerializerClearedAfterClearBuffers() throws Exception {
		ResultPartitionWriter partitionWriter =
			spy(new RecyclingPartitionWriter(new TestPooledBufferProvider(1, 16)));

		RecordWriter<IntValue> recordWriter = createRecordWriter(partitionWriter);

		// Fill a buffer, but don't write it out.
		recordWriter.emit(new IntValue(0));

		// Clear all buffers.
		recordWriter.clearBuffers();

		// This should not throw an Exception iff the serializer state
		// has been cleared as expected.
		recordWriter.flushAll();
	}

	/**
	 * Tests broadcasting events when no records have been emitted yet.
	 */
	@Test
	public void testBroadcastEventNoRecords() throws Exception {
		int numberOfChannels = 4;
		int bufferSize = 32;

		@SuppressWarnings("unchecked")
		Queue<BufferConsumer>[] queues = new Queue[numberOfChannels];
		for (int i = 0; i < numberOfChannels; i++) {
			queues[i] = new ArrayDeque<>();
		}

		TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize);

		ResultPartitionWriter partitionWriter = new CollectingPartitionWriter(queues, bufferProvider);
		RecordWriter<ByteArrayIO> writer = createRecordWriter(partitionWriter);
		CheckpointBarrier barrier = new CheckpointBarrier(Integer.MAX_VALUE + 919192L, Integer.MAX_VALUE + 18828228L, CheckpointOptions.forCheckpointWithDefaultLocation());

		// No records emitted yet, broadcast should not request a buffer
		writer.broadcastEvent(barrier);

		assertEquals(0, bufferProvider.getNumberOfCreatedBuffers());

		for (int i = 0; i < numberOfChannels; i++) {
			assertEquals(1, queues[i].size());
			BufferOrEvent boe = parseBuffer(queues[i].remove(), i);
			assertTrue(boe.isEvent());
			assertEquals(barrier, boe.getEvent());
			assertEquals(0, queues[i].size());
		}
	}

	/**
	 * Tests broadcasting events when records have been emitted. The emitted
	 * records cover all three {@link SerializationResult} types.
	 */
	@Test
	public void testBroadcastEventMixedRecords() throws Exception {
		Random rand = new XORShiftRandom();
		int numberOfChannels = 4;
		int bufferSize = 32;
		int lenBytes = 4; // serialized length

		@SuppressWarnings("unchecked")
		Queue<BufferConsumer>[] queues = new Queue[numberOfChannels];
		for (int i = 0; i < numberOfChannels; i++) {
			queues[i] = new ArrayDeque<>();
		}

		TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize);

		ResultPartitionWriter partitionWriter = new CollectingPartitionWriter(queues, bufferProvider);
		RecordWriter<ByteArrayIO> writer = createRecordWriter(partitionWriter);
		CheckpointBarrier barrier = new CheckpointBarrier(Integer.MAX_VALUE + 1292L, Integer.MAX_VALUE + 199L, CheckpointOptions.forCheckpointWithDefaultLocation());

		// Emit records on some channels first (requesting buffers), then
		// broadcast the event. The record buffers should be emitted first, then
		// the event. After the event, no new buffer should be requested.

		// (i) Smaller than the buffer size
		byte[] bytes = new byte[bufferSize / 2];
		rand.nextBytes(bytes);

		writer.emit(new ByteArrayIO(bytes));

		// (ii) Larger than the buffer size
		bytes = new byte[bufferSize + 1];
		rand.nextBytes(bytes);

		writer.emit(new ByteArrayIO(bytes));

		// (iii) Exactly the buffer size (payload plus the length header)
		bytes = new byte[bufferSize - lenBytes];
		rand.nextBytes(bytes);

		writer.emit(new ByteArrayIO(bytes));

		// (iv) Broadcast the event
		writer.broadcastEvent(barrier);

		if (isBroadcastWriter) {
			assertEquals(3, bufferProvider.getNumberOfCreatedBuffers());

			for (int i = 0; i < numberOfChannels; i++) {
				assertEquals(4, queues[i].size()); // 3 buffer + 1 event

				for (int j = 0; j < 3; j++) {
					assertTrue(parseBuffer(queues[i].remove(), i).isBuffer());
				}

				BufferOrEvent boe = parseBuffer(queues[i].remove(), i);
				assertTrue(boe.isEvent());
				assertEquals(barrier, boe.getEvent());
			}
		} else {
			assertEquals(4, bufferProvider.getNumberOfCreatedBuffers());

			assertEquals(2, queues[0].size()); // 1 buffer + 1 event
			assertTrue(parseBuffer(queues[0].remove(), 0).isBuffer());
			assertEquals(3, queues[1].size()); // 2 buffers + 1 event
			assertTrue(parseBuffer(queues[1].remove(), 1).isBuffer());
			assertTrue(parseBuffer(queues[1].remove(), 1).isBuffer());
			assertEquals(2, queues[2].size()); // 1 buffer + 1 event
			assertTrue(parseBuffer(queues[2].remove(), 2).isBuffer());
			assertEquals(1, queues[3].size()); // 0 buffers + 1 event

			// every queue's last element should be the event
			for (int i = 0; i < numberOfChannels; i++) {
				BufferOrEvent boe = parseBuffer(queues[i].remove(), i);
				assertTrue(boe.isEvent());
				assertEquals(barrier, boe.getEvent());
			}
		}
	}

	/**
	 * Tests that event buffers are properly recycled when broadcasting events
	 * to multiple channels.
	 */
	@Test
	public void testBroadcastEventBufferReferenceCounting() throws Exception {

		@SuppressWarnings("unchecked")
		ArrayDeque<BufferConsumer>[] queues = new ArrayDeque[] { new ArrayDeque(), new ArrayDeque() };

		ResultPartitionWriter partition =
			new CollectingPartitionWriter(queues, new TestPooledBufferProvider(Integer.MAX_VALUE));
		RecordWriter<?> writer = createRecordWriter(partition);

		writer.broadcastEvent(EndOfPartitionEvent.INSTANCE);

		// Verify added to all queues
		assertEquals(1, queues[0].size());
		assertEquals(1, queues[1].size());

		// get references to buffer consumers (copies from the original event buffer consumer)
		BufferConsumer bufferConsumer1 = queues[0].getFirst();
		BufferConsumer bufferConsumer2 = queues[1].getFirst();

		// process all collected events (recycles the buffer)
		for (int i = 0; i < queues.length; i++) {
			assertTrue(parseBuffer(queues[i].remove(), i).isEvent());
		}

		assertTrue(bufferConsumer1.isRecycled());
		assertTrue(bufferConsumer2.isRecycled());
	}

	/**
	 * Tests that broadcasted events' buffers are independent (in their (reader) indices) once they
	 * are put into the queue for Netty when broadcasting events to multiple channels.
	 */
	@Test
	public void testBroadcastEventBufferIndependence() throws Exception {
		verifyBroadcastBufferOrEventIndependence(true);
	}

	/**
	 * Tests that broadcasted records' buffers are independent (in their (reader) indices) once they
	 * are put into the queue for Netty when broadcasting events to multiple channels.
	 */
	@Test
	public void testBroadcastEmitBufferIndependence() throws Exception {
		verifyBroadcastBufferOrEventIndependence(false);
	}

	/**
	 * Tests that records are broadcast via {@link RecordWriter#broadcastEmit(IOReadableWritable)}.
	 */
	@Test
	public void testBroadcastEmitRecord() throws Exception {
		final int numberOfChannels = 4;
		final int bufferSize = 32;
		final int numValues = 8;
		final int serializationLength = 4;

		@SuppressWarnings("unchecked")
		final Queue<BufferConsumer>[] queues = new Queue[numberOfChannels];
		for (int i = 0; i < numberOfChannels; i++) {
			queues[i] = new ArrayDeque<>();
		}

		final TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize);
		final ResultPartitionWriter partitionWriter = new CollectingPartitionWriter(queues, bufferProvider);
		final RecordWriter<SerializationTestType> writer = createRecordWriter(partitionWriter);
		final RecordDeserializer<SerializationTestType> deserializer = new SpillingAdaptiveSpanningRecordDeserializer<>(
			new String[]{ tempFolder.getRoot().getAbsolutePath() });

		final ArrayDeque<SerializationTestType> serializedRecords = new ArrayDeque<>();
		final Iterable<SerializationTestType> records = Util.randomRecords(numValues, SerializationTestTypeFactory.INT);
		for (SerializationTestType record : records) {
			serializedRecords.add(record);
			writer.broadcastEmit(record);
		}

		// each record occupies a 4-byte length header plus serializationLength payload bytes
		final int numRequiredBuffers = numValues / (bufferSize / (4 + serializationLength));
		if (isBroadcastWriter) {
			assertEquals(numRequiredBuffers, bufferProvider.getNumberOfCreatedBuffers());
		} else {
			assertEquals(numRequiredBuffers * numberOfChannels, bufferProvider.getNumberOfCreatedBuffers());
		}

		for (int i = 0; i < numberOfChannels; i++) {
			assertEquals(numRequiredBuffers, queues[i].size());
			verifyDeserializationResults(queues[i], deserializer, serializedRecords.clone(), numRequiredBuffers, numValues);
		}
	}

	/**
	 * Tests that the RecordWriter is available iff the respective LocalBufferPool has at least one available buffer.
	 */
	@Test
	public void testIsAvailableOrNot() throws Exception {
		// setup
		final NetworkBufferPool globalPool = new NetworkBufferPool(10, 128, 2);
		final BufferPool localPool = globalPool.createBufferPool(1, 1, null, 1, Integer.MAX_VALUE);
		final ResultPartitionWriter resultPartition = new ResultPartitionBuilder()
			.setBufferPoolFactory(p -> localPool)
			.build();
		resultPartition.setup();
		final ResultPartitionWriter partitionWrapper = new ConsumableNotifyingResultPartitionWriterDecorator(
			new NoOpTaskActions(),
			new JobID(),
			resultPartition,
			new NoOpResultPartitionConsumableNotifier());
		final RecordWriter recordWriter = createRecordWriter(partitionWrapper);

		try {
			// record writer is available because the local pool initially has an available buffer
			assertTrue(recordWriter.getAvailableFuture().isDone());

			// request one buffer from the local pool to make it unavailable afterwards
			final BufferBuilder bufferBuilder = resultPartition.getBufferBuilder(0);
			assertNotNull(bufferBuilder);
			assertFalse(recordWriter.getAvailableFuture().isDone());

			// recycle the buffer to make the local pool available again
			final Buffer buffer = BufferBuilderTestUtils.buildSingleBuffer(bufferBuilder);
			buffer.recycleBuffer();
			assertTrue(recordWriter.getAvailableFuture().isDone());
			// AVAILABLE is a static constant; access it through the type, not the instance
			assertEquals(RecordWriter.AVAILABLE, recordWriter.getAvailableFuture());

		} finally {
			localPool.lazyDestroy();
			globalPool.destroy();
		}
	}

	/**
	 * Tests that after reading recovered channel state into the partition, subsequently
	 * broadcast records are emitted after the recovered state buffers in every subpartition.
	 */
	@Test
	public void testEmitRecordWithPartitionStateRecovery() throws Exception {
		final int totalBuffers = 10; // enough for both states and normal records
		final int totalStates = 2;
		final int[] states = {1, 2, 3, 4};
		final int[] records = {5, 6, 7, 8};
		final int bufferSize = states.length * Integer.BYTES;

		final NetworkBufferPool globalPool = new NetworkBufferPool(totalBuffers, bufferSize, 1);
		final ChannelStateReader stateReader = new ResultPartitionTest.FiniteChannelStateReader(totalStates, states);
		final ResultPartition partition = new ResultPartitionBuilder()
			.setNetworkBufferPool(globalPool)
			.build();
		final RecordWriter<IntValue> recordWriter = new RecordWriterBuilder<IntValue>().build(partition);

		try {
			partition.setup();
			partition.readRecoveredState(stateReader);

			for (int record: records) {
				// the record length 4 is also written into buffer for every emit
				recordWriter.broadcastEmit(new IntValue(record));
			}

			// each buffer holds two records, each written as a 4-byte length followed by the int value
			final int[][] expectedRecordsInBuffer = {{4, 5, 4, 6}, {4, 7, 4, 8}};

			for (ResultSubpartition subpartition : partition.getAllPartitions()) {
				// create the view to consume all the buffers with states and records
				final ResultSubpartitionView view = new PipelinedSubpartitionView(
					(PipelinedSubpartition) subpartition,
					new NoOpBufferAvailablityListener());

				int numConsumedBuffers = 0;
				ResultSubpartition.BufferAndBacklog bufferAndBacklog;
				while ((bufferAndBacklog = view.getNextBuffer()) != null) {
					Buffer buffer = bufferAndBacklog.buffer();
					// the first totalStates buffers carry recovered state, the rest carry records
					int[] expected = numConsumedBuffers < totalStates ? states : expectedRecordsInBuffer[numConsumedBuffers - totalStates];
					BufferBuilderAndConsumerTest.assertContent(
						buffer,
						partition.getBufferPool()
							.getSubpartitionBufferRecyclers()[subpartition.getSubPartitionIndex()],
						expected);

					buffer.recycleBuffer();
					numConsumedBuffers++;
				}

				assertEquals(totalStates + expectedRecordsInBuffer.length, numConsumedBuffers);
			}
		} finally {
			// cleanup
			globalPool.destroyAllBufferPools();
			globalPool.destroy();
		}
	}

	/**
	 * Tests that the idle-time metric stays at zero while a buffer is available and increases
	 * while a buffer request is blocked waiting for a buffer to be recycled.
	 */
	@Test
	public void testIdleTime() throws IOException, InterruptedException {
		// setup
		final NetworkBufferPool globalPool = new NetworkBufferPool(10, 128, 2);
		final BufferPool localPool = globalPool.createBufferPool(1, 1, null, 1, Integer.MAX_VALUE);
		try {
			final ResultPartitionWriter resultPartition = new ResultPartitionBuilder()
				.setBufferPoolFactory(p -> localPool)
				.build();
			resultPartition.setup();
			final ResultPartitionWriter partitionWrapper = new ConsumableNotifyingResultPartitionWriterDecorator(
				new NoOpTaskActions(),
				new JobID(),
				resultPartition,
				new NoOpResultPartitionConsumableNotifier());
			final RecordWriter recordWriter = createRecordWriter(partitionWrapper);

			// take the only buffer of the local pool and hand it to a read view
			BufferBuilder builder = recordWriter.requestNewBufferBuilder(0);
			BufferBuilderTestUtils.fillBufferBuilder(builder, 1).finish();
			ResultSubpartitionView readView = resultPartition.getSubpartition(0).createReadView(new NoOpBufferAvailablityListener());
			Buffer buffer = readView.getNextBuffer().buffer();

			// idle time is zero when there is buffer available.
			assertEquals(0, recordWriter.getIdleTimeMsPerSecond().getCount());

			final CountDownLatch syncLock = new CountDownLatch(1);
			final AtomicReference<BufferBuilder> asyncRequestResult = new AtomicReference<>();
			final AtomicReference<Throwable> asyncRequestError = new AtomicReference<>();
			final Thread requestThread = new Thread(() -> {
				try {
					// notify that the request thread start to run.
					syncLock.countDown();
					// wait for buffer: the only buffer is held until it is recycled below.
					asyncRequestResult.set(recordWriter.requestNewBufferBuilder(0));
				} catch (Throwable t) {
					// hand the failure to the test thread instead of silently swallowing it
					asyncRequestError.set(t);
				}
			});
			requestThread.start();

			// wait until request thread start to run.
			syncLock.await();

			// give the request thread time to block on the exhausted pool so that idle time accrues
			Thread.sleep(10);

			// recycle the buffer to unblock the request thread
			buffer.recycleBuffer();
			requestThread.join();

			if (asyncRequestError.get() != null) {
				throw new AssertionError("The asynchronous buffer request failed.", asyncRequestError.get());
			}
			assertThat(recordWriter.getIdleTimeMsPerSecond().getCount(), Matchers.greaterThan(0L));
			assertNotNull(asyncRequestResult.get());
		} finally {
			localPool.lazyDestroy();
			globalPool.destroy();
		}
	}

	/**
	 * Broadcasts either an event or a record to two channels and verifies that the resulting
	 * buffers do not share their reader index.
	 *
	 * @param broadcastEvent {@code true} to broadcast {@link EndOfPartitionEvent}, {@code false}
	 *                       to broadcast-emit an {@link IntValue} record
	 */
	private void verifyBroadcastBufferOrEventIndependence(boolean broadcastEvent) throws Exception {
		@SuppressWarnings("unchecked")
		ArrayDeque<BufferConsumer>[] queues = new ArrayDeque[]{new ArrayDeque(), new ArrayDeque()};

		ResultPartitionWriter partition =
			new CollectingPartitionWriter(queues, new TestPooledBufferProvider(Integer.MAX_VALUE));
		RecordWriter<IntValue> writer = createRecordWriter(partition);

		if (broadcastEvent) {
			writer.broadcastEvent(EndOfPartitionEvent.INSTANCE);
		} else {
			writer.broadcastEmit(new IntValue(0));
		}

		// verify added to all queues
		assertEquals(1, queues[0].size());
		assertEquals(1, queues[1].size());

		// these two buffers may share the memory but not the indices!
		Buffer buffer1 = buildSingleBuffer(queues[0].remove());
		Buffer buffer2 = buildSingleBuffer(queues[1].remove());
		assertEquals(0, buffer1.getReaderIndex());
		assertEquals(0, buffer2.getReaderIndex());
		buffer1.setReaderIndex(1);
		assertEquals("Buffer 2 shares the same reader index as buffer 1", 0, buffer2.getReaderIndex());

		// release the buffers built above
		buffer1.recycleBuffer();
		buffer2.recycleBuffer();
	}

	/**
	 * Builds {@code numRequiredBuffers} buffers from {@code queue}, feeds them to the deserializer
	 * and asserts that {@code numValues} records were deserialized in total.
	 *
	 * <p>Record-by-record comparison against {@code expectedRecords} is delegated to
	 * {@link DeserializationUtils#deserializeRecords} (presumably it checks each record; see there).
	 */
	protected void verifyDeserializationResults(
			Queue<BufferConsumer> queue,
			RecordDeserializer<SerializationTestType> deserializer,
			ArrayDeque<SerializationTestType> expectedRecords,
			int numRequiredBuffers,
			int numValues) throws Exception {
		int assertRecords = 0;
		for (int j = 0; j < numRequiredBuffers; j++) {
			Buffer buffer = buildSingleBuffer(queue.remove());
			deserializer.setNextBuffer(buffer);

			assertRecords += DeserializationUtils.deserializeRecords(expectedRecords, deserializer);
		}
		Assert.assertEquals(numValues, assertRecords);
	}

	/**
	 * Creates the {@link RecordWriter} instance based on whether it is a broadcast writer: with a
	 * {@link ShipStrategyType#BROADCAST} {@link OutputEmitter} as channel selector if so, and with
	 * the builder defaults otherwise.
	 */
	private RecordWriter createRecordWriter(ResultPartitionWriter writer) {
		if (isBroadcastWriter) {
			return new RecordWriterBuilder()
				.setChannelSelector(new OutputEmitter(ShipStrategyType.BROADCAST, 0))
				.build(writer);
		} else {
			return new RecordWriterBuilder().build(writer);
		}
	}

	// ---------------------------------------------------------------------------------------------
	// Helpers
	// ---------------------------------------------------------------------------------------------

	/**
	 * Partition writer that collects the added buffers/events in one queue per channel.
	 */
	static class CollectingPartitionWriter extends MockResultPartitionWriter {
		private final Queue<BufferConsumer>[] queues;
		private final BufferProvider bufferProvider;

		/**
		 * Create the partition writer.
		 *
		 * @param queues one queue per outgoing channel
		 * @param bufferProvider buffer provider
		 */
		CollectingPartitionWriter(Queue<BufferConsumer>[] queues, BufferProvider bufferProvider) {
			this.queues = queues;
			this.bufferProvider = bufferProvider;
		}

		@Override
		public int getNumberOfSubpartitions() {
			return queues.length;
		}

		@Override
		public BufferBuilder getBufferBuilder(int targetChannel) throws IOException, InterruptedException {
			return bufferProvider.requestBufferBuilderBlocking(targetChannel);
		}

		@Override
		public BufferBuilder tryGetBufferBuilder(int targetChannel) throws IOException {
			return bufferProvider.requestBufferBuilder(targetChannel);
		}

		@Override
		public boolean addBufferConsumer(BufferConsumer buffer, int targetChannel, boolean isPriorityEvent) {
			return queues[targetChannel].add(buffer);
		}
	}

	/**
	 * Builds the buffer of the given consumer and wraps it as a {@link BufferOrEvent} for the
	 * given channel. Event buffers are deserialized into an event and recycled immediately.
	 */
	static BufferOrEvent parseBuffer(BufferConsumer bufferConsumer, int targetChannel) throws IOException {
		Buffer buffer = buildSingleBuffer(bufferConsumer);
		if (buffer.isBuffer()) {
			return new BufferOrEvent(buffer, new InputChannelInfo(0, targetChannel));
		} else {
			// is event:
			AbstractEvent event = EventSerializer.fromBuffer(buffer, RecordWriterTest.class.getClassLoader());
			buffer.recycleBuffer(); // the buffer is not needed anymore
			return new BufferOrEvent(event, new InputChannelInfo(0, targetChannel));
		}
	}

	/**
	 * Partition writer that serves buffers from the given provider and leaves the handling of
	 * added buffers to the {@link MockResultPartitionWriter} defaults (intended to recycle them
	 * without further processing).
	 */
	private static class RecyclingPartitionWriter extends MockResultPartitionWriter {
		private final BufferProvider bufferProvider;

		private RecyclingPartitionWriter(BufferProvider bufferProvider) {
			this.bufferProvider = bufferProvider;
		}

		@Override
		public BufferBuilder getBufferBuilder(int targetChannel) throws IOException, InterruptedException {
			return bufferProvider.requestBufferBuilderBlocking(targetChannel);
		}

		@Override
		public BufferBuilder tryGetBufferBuilder(int targetChannel) throws IOException {
			return bufferProvider.requestBufferBuilder(targetChannel);
		}
	}

	/**
	 * Partition writer that keeps every added {@link BufferConsumer} (and thus keeps its buffer
	 * occupied) until {@link #close()} closes all of them.
	 */
	static class KeepingPartitionWriter extends MockResultPartitionWriter {
		private final BufferProvider bufferProvider;
		private final Map<Integer, List<BufferConsumer>> produced = new HashMap<>();

		KeepingPartitionWriter(BufferProvider bufferProvider) {
			this.bufferProvider = bufferProvider;
		}

		@Override
		public BufferBuilder getBufferBuilder(int targetChannel) throws IOException, InterruptedException {
			return bufferProvider.requestBufferBuilderBlocking(targetChannel);
		}

		@Override
		public BufferBuilder tryGetBufferBuilder(int targetChannel) throws IOException {
			return bufferProvider.requestBufferBuilder(targetChannel);
		}

		@Override
		public boolean addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel, boolean isPriorityEvent) {
			// keep the buffer occupied.
			produced.computeIfAbsent(targetChannel, channel -> new ArrayList<>()).add(bufferConsumer);
			return true;
		}

		/**
		 * Returns the buffer consumers added for the given subpartition, or {@code null} if none
		 * have been added.
		 */
		public List<BufferConsumer> getAddedBufferConsumers(int subpartitionIndex) {
			return produced.get(subpartitionIndex);
		}

		@Override
		public void close() {
			for (List<BufferConsumer> bufferConsumers : produced.values()) {
				for (BufferConsumer bufferConsumer : bufferConsumers) {
					bufferConsumer.close();
				}
			}
			produced.clear();
		}
	}

	/**
	 * Record that writes its byte array as-is (no length of its own) and reads back exactly
	 * {@code bytes.length} bytes into the same array.
	 */
	private static class ByteArrayIO implements IOReadableWritable {

		private final byte[] bytes;

		public ByteArrayIO(byte[] bytes) {
			this.bytes = bytes;
		}

		@Override
		public void write(DataOutputView out) throws IOException {
			out.write(bytes);
		}

		@Override
		public void read(DataInputView in) throws IOException {
			in.readFully(bytes);
		}
	}

	/**
	 * Buffer recycler that records every memory segment passed to {@link #recycle(MemorySegment)}.
	 */
	private static class TrackingBufferRecycler implements BufferRecycler {
		private final List<MemorySegment> recycledMemorySegments = new ArrayList<>();

		@Override
		public synchronized void recycle(MemorySegment memorySegment) {
			recycledMemorySegments.add(memorySegment);
		}

		/**
		 * Returns a snapshot of the recycled segments. A copy is returned so that callers do not
		 * access the internal list outside of this object's lock.
		 */
		public synchronized List<MemorySegment> getRecycledMemorySegments() {
			return new ArrayList<>(recycledMemorySegments);
		}
	}
}