/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.state;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.state.BroadcastState;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.concurrent.FutureUtils;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory;
import org.apache.flink.runtime.util.BlockingCheckpointOutputStream;
import org.apache.flink.util.Preconditions;
import org.junit.Assert;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class OperatorStateBackendTest {

	// Class loader of this test class; passed as the first argument to DefaultOperatorStateBackendBuilder.
	private final ClassLoader classLoader = getClass().getClassLoader();
	// Empty collection passed wherever the tests hand state handles to a backend that should start without state.
	private final Collection<OperatorStateHandle> emptyStateHandles = Collections.emptyList();

	/**
	 * Creates an operator state backend through {@link MemoryStateBackend}, used here as a concrete
	 * {@link AbstractStateBackend}, and checks that it reports no registered state names.
	 */
	@Test
	public void testCreateOnAbstractStateBackend() throws Exception {
		final AbstractStateBackend backendFactory = new MemoryStateBackend();
		final CloseableRegistry registry = new CloseableRegistry();

		final OperatorStateBackend backend = backendFactory.createOperatorStateBackend(
			createMockEnvironment(),
			"test-operator",
			emptyStateHandles,
			registry);
		assertNotNull(backend);

		// neither list states nor broadcast states may be registered on a fresh backend
		final boolean noListStates = backend.getRegisteredStateNames().isEmpty();
		assertTrue(noListStates);
		final boolean noBroadcastStates = backend.getRegisteredBroadcastStateNames().isEmpty();
		assertTrue(noBroadcastStates);
	}

	/**
	 * Registers list states from descriptors that only carry a type class and checks that the Kryo
	 * registration made on the {@link ExecutionConfig} is visible in the serializer of the state.
	 */
	@Test
	public void testRegisterStatesWithoutTypeSerializer() throws Exception {
		// a non-standard type that receives an explicit Kryo serializer registration below
		final Class<?> kryoRegisteredType = FutureTask.class;

		// precondition of this test: without the registration Kryo must not already choose the
		// JavaSerializer for that type - if this fails, a different example serializer is needed
		final Object defaultKryoSerializer =
			new KryoSerializer<>(File.class, new ExecutionConfig())
				.getKryo()
				.getDefaultSerializer(kryoRegisteredType);
		assertFalse(defaultKryoSerializer instanceof com.esotericsoftware.kryo.serializers.JavaSerializer);

		final ExecutionConfig executionConfig = new ExecutionConfig();
		executionConfig.registerTypeWithKryoSerializer(
			kryoRegisteredType,
			com.esotericsoftware.kryo.serializers.JavaSerializer.class);

		final DefaultOperatorStateBackendBuilder builder = new DefaultOperatorStateBackendBuilder(
			classLoader,
			executionConfig,
			false,
			emptyStateHandles,
			new CloseableRegistry());
		final OperatorStateBackend backend = builder.build();

		final ListStateDescriptor<File> fileDescriptor = new ListStateDescriptor<>("test", File.class);
		final ListStateDescriptor<String> stringDescriptor = new ListStateDescriptor<>("test2", String.class);

		final ListState<File> fileState = backend.getListState(fileDescriptor);
		assertNotNull(fileState);

		final ListState<String> stringState = backend.getListState(stringDescriptor);
		assertNotNull(stringState);

		assertEquals(2, backend.getRegisteredStateNames().size());

		// make sure that type registrations are forwarded
		final TypeSerializer<?> fileStateSerializer =
			((PartitionableListState<?>) fileState).getStateMetaInfo().getPartitionStateSerializer();
		assertTrue(fileStateSerializer instanceof KryoSerializer);
		final Object registeredKryoSerializer =
			((KryoSerializer<?>) fileStateSerializer).getKryo().getSerializer(kryoRegisteredType);
		assertTrue(registeredKryoSerializer instanceof com.esotericsoftware.kryo.serializers.JavaSerializer);

		// the second state starts empty ...
		Iterator<String> names = stringState.get().iterator();
		assertFalse(names.hasNext());

		// ... and returns added elements in insertion order
		stringState.add("kevin");
		stringState.add("sunny");

		names = stringState.get().iterator();
		assertEquals("kevin", names.next());
		assertEquals("sunny", names.next());
		assertFalse(names.hasNext());
	}

	/**
	 * Registers two list states and one union list state, checks their contents after adding
	 * elements, checks that re-requesting a state under the same descriptor exposes the same
	 * contents, and expects an {@link IllegalStateException} when a state is requested again
	 * through the other accessor (list vs. union list).
	 */
	@Test
	public void testRegisterStates() throws Exception {
		final OperatorStateBackend operatorStateBackend =
			new DefaultOperatorStateBackendBuilder(
				classLoader,
				new ExecutionConfig(),
				false,
				emptyStateHandles,
				new CloseableRegistry()).build();

		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
		ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());

		// first list state: starts empty, then holds the added elements in order
		ListState<Serializable> listState1 = operatorStateBackend.getListState(stateDescriptor1);
		assertNotNull(listState1);
		assertEquals(1, operatorStateBackend.getRegisteredStateNames().size());
		Iterator<Serializable> it = listState1.get().iterator();
		assertFalse(it.hasNext());
		listState1.add(42);
		listState1.add(4711);

		it = listState1.get().iterator();
		assertEquals(42, it.next());
		assertEquals(4711, it.next());
		assertFalse(it.hasNext());

		// second list state
		ListState<Serializable> listState2 = operatorStateBackend.getListState(stateDescriptor2);
		assertNotNull(listState2);
		assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
		// check emptiness on an iterator of the NEW state; re-checking the exhausted iterator
		// of the previous state would pass regardless of what the new state contains
		it = listState2.get().iterator();
		assertFalse(it.hasNext());
		listState2.add(7);
		listState2.add(13);
		listState2.add(23);

		it = listState2.get().iterator();
		assertEquals(7, it.next());
		assertEquals(13, it.next());
		assertEquals(23, it.next());
		assertFalse(it.hasNext());

		// union list state
		ListState<Serializable> listState3 = operatorStateBackend.getUnionListState(stateDescriptor3);
		assertNotNull(listState3);
		assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
		// same as above: the emptiness check must look at the freshly registered state
		it = listState3.get().iterator();
		assertFalse(it.hasNext());
		listState3.add(17);
		listState3.add(3);
		listState3.add(123);

		it = listState3.get().iterator();
		assertEquals(17, it.next());
		assertEquals(3, it.next());
		assertEquals(123, it.next());
		assertFalse(it.hasNext());

		// requesting the first state again must expose the elements added so far
		ListState<Serializable> listState1b = operatorStateBackend.getListState(stateDescriptor1);
		assertNotNull(listState1b);
		listState1b.add(123);
		it = listState1b.get().iterator();
		assertEquals(42, it.next());
		assertEquals(4711, it.next());
		assertEquals(123, it.next());
		assertFalse(it.hasNext());

		// the element added through the second reference is visible through the first one too
		it = listState1.get().iterator();
		assertEquals(42, it.next());
		assertEquals(4711, it.next());
		assertEquals(123, it.next());
		assertFalse(it.hasNext());

		it = listState1b.get().iterator();
		assertEquals(42, it.next());
		assertEquals(4711, it.next());
		assertEquals(123, it.next());
		assertFalse(it.hasNext());

		// a state registered as list state must not be obtainable as union list state ...
		try {
			operatorStateBackend.getUnionListState(stateDescriptor2);
			fail("Did not detect changed mode");
		} catch (IllegalStateException ignored) {
			// expected: the mode of an already registered state was changed
		}

		// ... and vice versa
		try {
			operatorStateBackend.getListState(stateDescriptor3);
			fail("Did not detect changed mode");
		} catch (IllegalStateException ignored) {
			// expected: the mode of an already registered state was changed
		}
	}

	/**
	 * Takes a snapshot of a list state and a broadcast state whose serializers are
	 * {@link VerifyingIntSerializer}s and checks that each serializer's copy method was invoked;
	 * the serializers themselves assert the context class loader on every copy.
	 */
	@SuppressWarnings("unchecked")
	@Test
	public void testCorrectClassLoaderUsedOnSnapshot() throws Exception {

		final AbstractStateBackend backendFactory = new MemoryStateBackend(4096);
		final Environment environment = createMockEnvironment();
		final CloseableRegistry registry = new CloseableRegistry();

		final OperatorStateBackend backend =
			backendFactory.createOperatorStateBackend(environment, "test-op-name", emptyStateHandles, registry);

		final ClassLoader userClassLoader = environment.getUserClassLoader();

		// list state with one element
		final AtomicInteger listCopies = new AtomicInteger(0);
		final TypeSerializer<Integer> listSerializer = new VerifyingIntSerializer(userClassLoader, listCopies);
		final ListState<Integer> listState =
			backend.getListState(new ListStateDescriptor<>("test", listSerializer));
		listState.add(42);

		// broadcast state with three entries, separate counters for key and value serializer
		final AtomicInteger keyCopies = new AtomicInteger(0);
		final AtomicInteger valueCopies = new AtomicInteger(0);
		final TypeSerializer<Integer> keySerializer = new VerifyingIntSerializer(userClassLoader, keyCopies);
		final TypeSerializer<Integer> valueSerializer = new VerifyingIntSerializer(userClassLoader, valueCopies);

		final MapStateDescriptor<Integer, Integer> broadcastDescriptor =
			new MapStateDescriptor<>("test-broadcast", keySerializer, valueSerializer);
		final BroadcastState<Integer, Integer> broadcastState = backend.getBroadcastState(broadcastDescriptor);
		broadcastState.put(1, 2);
		broadcastState.put(3, 4);
		broadcastState.put(5, 6);

		// run the snapshot to completion
		final CheckpointStreamFactory streams = new MemCheckpointStreamFactory(4096);
		final RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshotFuture =
			backend.snapshot(1, 1, streams, CheckpointOptions.forCheckpointWithDefaultLocation());
		FutureUtils.runIfNotDoneAndGet(snapshotFuture);

		// make sure that the copy method has been called on every serializer
		assertTrue(listCopies.get() > 0);
		assertTrue(keyCopies.get() > 0);
		assertTrue(valueCopies.get() > 0);
	}

	/**
	 * Int serializer that delegates all work to {@link IntSerializer} and, on every copy operation,
	 * asserts that the context class loader of the calling thread is the expected one and counts
	 * the call.
	 */
	private static final class VerifyingIntSerializer extends TypeSerializer<Integer> {

		private static final long serialVersionUID = -5344563614550163898L;

		/** Serializer performing the actual int handling. */
		private static final IntSerializer DELEGATE = IntSerializer.INSTANCE;

		/** Class loader that must be the context class loader whenever a copy method runs. */
		private transient ClassLoader expectedClassLoader;

		/** Incremented once per invocation of any copy method. */
		private transient AtomicInteger copyCalls;

		private VerifyingIntSerializer(ClassLoader classLoader, AtomicInteger atomicInteger) {
			this.expectedClassLoader = Preconditions.checkNotNull(classLoader);
			this.copyCalls = Preconditions.checkNotNull(atomicInteger);
		}

		/** Asserts the context class loader of the current thread and records one copy call. */
		private void verifyAndCount() {
			assertEquals(expectedClassLoader, Thread.currentThread().getContextClassLoader());
			copyCalls.incrementAndGet();
		}

		// ---- copy operations: verified and counted ----

		@Override
		public Integer copy(Integer from) {
			verifyAndCount();
			return DELEGATE.copy(from);
		}

		@Override
		public Integer copy(Integer from, Integer reuse) {
			verifyAndCount();
			return DELEGATE.copy(from, reuse);
		}

		@Override
		public void copy(DataInputView source, DataOutputView target) throws IOException {
			verifyAndCount();
			DELEGATE.copy(source, target);
		}

		// ---- plain delegation ----

		@Override
		public Integer createInstance() {
			return 0;
		}

		@Override
		public int getLength() {
			return DELEGATE.getLength();
		}

		@Override
		public void serialize(Integer record, DataOutputView target) throws IOException {
			DELEGATE.serialize(record, target);
		}

		@Override
		public Integer deserialize(DataInputView source) throws IOException {
			return DELEGATE.deserialize(source);
		}

		@Override
		public Integer deserialize(Integer reuse, DataInputView source) throws IOException {
			return DELEGATE.deserialize(reuse, source);
		}

		// ---- serializer meta data ----

		@Override
		public boolean isImmutableType() {
			// must report a mutable type, otherwise the copy method won't be called for the deepCopy operation
			return false;
		}

		@Override
		public TypeSerializer<Integer> duplicate() {
			return this;
		}

		@Override
		public TypeSerializerSnapshot<Integer> snapshotConfiguration() {
			return new VerifyingIntSerializerSnapshot();
		}

		@Override
		public boolean equals(Object obj) {
			// all instances are considered equal, independent of class loader and counter
			return obj instanceof VerifyingIntSerializer;
		}

		@Override
		public int hashCode() {
			// the class is final, so this is the same value as getClass().hashCode()
			return VerifyingIntSerializer.class.hashCode();
		}
	}

	/**
	 * Serializer snapshot for {@link VerifyingIntSerializer}. The supplied serializer expects the
	 * context class loader of the thread that invokes the supplier and starts with a new counter.
	 */
	@SuppressWarnings("WeakerAccess")
	public static class VerifyingIntSerializerSnapshot extends SimpleTypeSerializerSnapshot<Integer> {
		public VerifyingIntSerializerSnapshot() {
			super(VerifyingIntSerializerSnapshot::serializerForCurrentThread);
		}

		/** Builds a serializer bound to the current thread's context class loader. */
		private static VerifyingIntSerializer serializerForCurrentThread() {
			final ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
			return new VerifyingIntSerializer(contextClassLoader, new AtomicInteger());
		}
	}

	/**
	 * Snapshots a backend without any registered state and expects the job-manager-owned
	 * snapshot of the result to be {@code null}.
	 */
	@Test
	public void testSnapshotEmpty() throws Exception {
		final AbstractStateBackend backendFactory = new MemoryStateBackend(4096);
		final CloseableRegistry registry = new CloseableRegistry();

		final OperatorStateBackend backend = backendFactory.createOperatorStateBackend(
			createMockEnvironment(),
			"testOperator",
			emptyStateHandles,
			registry);

		final CheckpointStreamFactory streams = new MemCheckpointStreamFactory(4096);
		final RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshotFuture = backend.snapshot(
			0L,
			0L,
			streams,
			CheckpointOptions.forCheckpointWithDefaultLocation());

		final SnapshotResult<OperatorStateHandle> result = FutureUtils.runIfNotDoneAndGet(snapshotFuture);
		assertNull(result.getJobManagerOwnedSnapshot());
	}

	/**
	 * Snapshots and restores a backend that only holds a broadcast state, three times in a row:
	 * with three entries, after removing one entry, and after clearing the state. After every
	 * restore the broadcast state contents must equal the expected map.
	 */
	@Test
	public void testSnapshotBroadcastStateWithEmptyOperatorState() throws Exception {
		final AbstractStateBackend backendFactory = new MemoryStateBackend(4096);

		OperatorStateBackend backend = backendFactory.createOperatorStateBackend(
			createMockEnvironment(), "testOperator", emptyStateHandles, new CloseableRegistry());

		final MapStateDescriptor<Integer, Integer> descriptor = new MapStateDescriptor<>(
			"test-broadcast", BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);

		final Map<Integer, Integer> expected = new HashMap<>(3);
		expected.put(1, 2);
		expected.put(3, 4);
		expected.put(5, 6);

		final BroadcastState<Integer, Integer> initialState = backend.getBroadcastState(descriptor);
		initialState.putAll(expected);

		final CheckpointStreamFactory streams = new MemCheckpointStreamFactory(4096);

		// handle of the most recent snapshot that has not been discarded yet
		OperatorStateHandle currentHandle = null;

		try {
			// --- round 1: snapshot and restore the full state ---
			RunnableFuture<SnapshotResult<OperatorStateHandle>> pending =
				backend.snapshot(0L, 0L, streams, CheckpointOptions.forCheckpointWithDefaultLocation());
			SnapshotResult<OperatorStateHandle> result = FutureUtils.runIfNotDoneAndGet(pending);
			currentHandle = result.getJobManagerOwnedSnapshot();
			assertNotNull(currentHandle);

			backend = recreateOperatorStateBackend(
				backend, backendFactory, StateObjectCollection.singleton(currentHandle));
			BroadcastState<Integer, Integer> restored = backend.getBroadcastState(descriptor);
			assertEquals(expected, collectBroadcastEntries(restored.entries()));

			// --- round 2: remove an element from both expected and stored state ---
			restored.remove(1);
			expected.remove(1);

			pending = backend.snapshot(1L, 1L, streams, CheckpointOptions.forCheckpointWithDefaultLocation());
			result = FutureUtils.runIfNotDoneAndGet(pending);

			currentHandle.discardState();
			currentHandle = result.getJobManagerOwnedSnapshot();

			backend = recreateOperatorStateBackend(
				backend, backendFactory, StateObjectCollection.singleton(currentHandle));
			restored = backend.getBroadcastState(descriptor);
			assertEquals(expected, collectBroadcastEntries(restored.immutableEntries()));

			// --- round 3: remove all elements from both expected and stored state ---
			restored.clear();
			expected.clear();

			pending = backend.snapshot(2L, 2L, streams, CheckpointOptions.forCheckpointWithDefaultLocation());
			result = FutureUtils.runIfNotDoneAndGet(pending);
			if (currentHandle != null) {
				currentHandle.discardState();
			}
			currentHandle = result.getJobManagerOwnedSnapshot();

			backend = recreateOperatorStateBackend(
				backend, backendFactory, StateObjectCollection.singleton(currentHandle));
			restored = backend.getBroadcastState(descriptor);
			final Map<Integer, Integer> afterClear = collectBroadcastEntries(restored.immutableEntries());
			assertTrue(expected.isEmpty());
			assertEquals(expected, afterClear);

			if (currentHandle != null) {
				currentHandle.discardState();
				// mark as discarded so that the finally block does not discard it a second time
				currentHandle = null;
			}
		} finally {
			backend.close();
			backend.dispose();
			if (currentHandle != null) {
				currentHandle.discardState();
			}
		}
	}

	/** Copies the given broadcast state entries into a new {@link HashMap}. */
	private static Map<Integer, Integer> collectBroadcastEntries(
			Iterable<? extends Map.Entry<Integer, Integer>> entries) {
		final Map<Integer, Integer> collected = new HashMap<>();
		for (Map.Entry<Integer, Integer> entry : entries) {
			collected.put(entry.getKey(), entry.getValue());
		}
		return collected;
	}

	/**
	 * Fills list, union list and broadcast states, snapshots them by running the snapshot future
	 * directly in the test thread, restores a new backend from the resulting handle and checks
	 * that all states come back with the written contents.
	 */
	@Test
	public void testSnapshotRestoreSync() throws Exception {
		final AbstractStateBackend backendFactory = new MemoryStateBackend(2 * 4096);

		OperatorStateBackend backend = backendFactory.createOperatorStateBackend(
			createMockEnvironment(),
			"test-op-name",
			emptyStateHandles,
			new CloseableRegistry());

		final ListStateDescriptor<Serializable> listDesc1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
		final ListStateDescriptor<Serializable> listDesc2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
		final ListStateDescriptor<Serializable> unionDesc = new ListStateDescriptor<>("test3", new JavaSerializer<>());

		final MapStateDescriptor<Serializable, Serializable> broadcastDesc1 =
			new MapStateDescriptor<>("test4", new JavaSerializer<>(), new JavaSerializer<>());
		final MapStateDescriptor<Serializable, Serializable> broadcastDesc2 =
			new MapStateDescriptor<>("test5", new JavaSerializer<>(), new JavaSerializer<>());
		final MapStateDescriptor<Serializable, Serializable> broadcastDesc3 =
			new MapStateDescriptor<>("test6", new JavaSerializer<>(), new JavaSerializer<>());

		ListState<Serializable> list1 = backend.getListState(listDesc1);
		ListState<Serializable> list2 = backend.getListState(listDesc2);
		ListState<Serializable> union = backend.getUnionListState(unionDesc);

		BroadcastState<Serializable, Serializable> broadcast1 = backend.getBroadcastState(broadcastDesc1);
		BroadcastState<Serializable, Serializable> broadcast2 = backend.getBroadcastState(broadcastDesc2);
		// the third broadcast state is registered but intentionally left empty
		BroadcastState<Serializable, Serializable> broadcast3 = backend.getBroadcastState(broadcastDesc3);

		list1.add(42);
		list1.add(4711);

		list2.add(7);
		list2.add(13);
		list2.add(23);

		union.add(17);
		union.add(18);
		union.add(19);
		union.add(20);

		broadcast1.put(1, 2);
		broadcast1.put(2, 5);

		broadcast2.put(2, 5);

		final CheckpointStreamFactory streams = new MemCheckpointStreamFactory(2 * 4096);
		final RunnableFuture<SnapshotResult<OperatorStateHandle>> pending =
			backend.snapshot(1L, 1L, streams, CheckpointOptions.forCheckpointWithDefaultLocation());

		final SnapshotResult<OperatorStateHandle> result = FutureUtils.runIfNotDoneAndGet(pending);
		final OperatorStateHandle handle = result.getJobManagerOwnedSnapshot();

		try {

			// drop the old backend and restore a new one from the snapshot
			backend.close();
			backend.dispose();

			backend = backendFactory.createOperatorStateBackend(
				createMockEnvironment(),
				"testOperator",
				StateObjectCollection.singleton(handle),
				new CloseableRegistry());

			// all states are known right after the restore ...
			assertEquals(3, backend.getRegisteredStateNames().size());
			assertEquals(3, backend.getRegisteredBroadcastStateNames().size());

			list1 = backend.getListState(listDesc1);
			list2 = backend.getListState(listDesc2);
			union = backend.getUnionListState(unionDesc);

			broadcast1 = backend.getBroadcastState(broadcastDesc1);
			broadcast2 = backend.getBroadcastState(broadcastDesc2);
			broadcast3 = backend.getBroadcastState(broadcastDesc3);

			// ... and requesting them again does not register additional ones
			assertEquals(3, backend.getRegisteredStateNames().size());
			assertEquals(3, backend.getRegisteredBroadcastStateNames().size());

			assertSyncRestoredElements(list1, 42, 4711);
			assertSyncRestoredElements(list2, 7, 13, 23);
			assertSyncRestoredElements(union, 17, 18, 19, 20);

			Iterator<Map.Entry<Serializable, Serializable>> entries = broadcast1.iterator();
			assertTrue(entries.hasNext());
			Map.Entry<Serializable, Serializable> current = entries.next();
			assertEquals(1, current.getKey());
			assertEquals(2, current.getValue());
			assertTrue(entries.hasNext());
			current = entries.next();
			assertEquals(2, current.getKey());
			assertEquals(5, current.getValue());
			assertFalse(entries.hasNext());

			entries = broadcast2.iterator();
			assertTrue(entries.hasNext());
			current = entries.next();
			assertEquals(2, current.getKey());
			assertEquals(5, current.getValue());
			assertFalse(entries.hasNext());

			entries = broadcast3.iterator();
			assertFalse(entries.hasNext());

			backend.close();
			backend.dispose();
		} finally {
			handle.discardState();
		}
	}

	/** Asserts that the list state yields exactly the given elements, in the given order. */
	private static void assertSyncRestoredElements(
			ListState<Serializable> state,
			Serializable... expectedElements) throws Exception {
		final Iterator<Serializable> actual = state.get().iterator();
		for (Serializable expectedElement : expectedElements) {
			assertEquals(expectedElement, actual.next());
		}
		assertFalse(actual.hasNext());
	}

	/**
	 * Runs a snapshot on a separate thread while the test thread mutates the states, then restores
	 * a new backend from the snapshot and checks that it holds the contents from before the
	 * mutations.
	 *
	 * <p>The executor is shut down in a {@code finally} block so that a failing assertion does not
	 * leave its worker thread behind, and the wait for the snapshot result is bounded to 60
	 * seconds, like in the other asynchronous snapshot tests of this class.
	 */
	@Test
	public void testSnapshotRestoreAsync() throws Exception {
		// third constructor argument 'true', as in the other asynchronous snapshot tests of this class
		OperatorStateBackend operatorStateBackend =
			new DefaultOperatorStateBackendBuilder(
				OperatorStateBackendTest.class.getClassLoader(),
				new ExecutionConfig(),
				true,
				emptyStateHandles,
				new CloseableRegistry()).build();

		ListStateDescriptor<MutableType> stateDescriptor1 =
				new ListStateDescriptor<>("test1", new JavaSerializer<MutableType>());
		ListStateDescriptor<MutableType> stateDescriptor2 =
				new ListStateDescriptor<>("test2", new JavaSerializer<MutableType>());
		ListStateDescriptor<MutableType> stateDescriptor3 =
				new ListStateDescriptor<>("test3", new JavaSerializer<MutableType>());

		MapStateDescriptor<MutableType, MutableType> broadcastStateDescriptor1 =
				new MapStateDescriptor<>("test4", new JavaSerializer<MutableType>(), new JavaSerializer<MutableType>());
		MapStateDescriptor<MutableType, MutableType> broadcastStateDescriptor2 =
				new MapStateDescriptor<>("test5", new JavaSerializer<MutableType>(), new JavaSerializer<MutableType>());
		MapStateDescriptor<MutableType, MutableType> broadcastStateDescriptor3 =
				new MapStateDescriptor<>("test6", new JavaSerializer<MutableType>(), new JavaSerializer<MutableType>());

		ListState<MutableType> listState1 = operatorStateBackend.getListState(stateDescriptor1);
		ListState<MutableType> listState2 = operatorStateBackend.getListState(stateDescriptor2);
		ListState<MutableType> listState3 = operatorStateBackend.getUnionListState(stateDescriptor3);

		BroadcastState<MutableType, MutableType> broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1);
		BroadcastState<MutableType, MutableType> broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2);
		BroadcastState<MutableType, MutableType> broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3);

		listState1.add(MutableType.of(42));
		listState1.add(MutableType.of(4711));

		listState2.add(MutableType.of(7));
		listState2.add(MutableType.of(13));
		listState2.add(MutableType.of(23));

		listState3.add(MutableType.of(17));
		listState3.add(MutableType.of(18));
		listState3.add(MutableType.of(19));
		listState3.add(MutableType.of(20));

		broadcastState1.put(MutableType.of(1), MutableType.of(2));
		broadcastState1.put(MutableType.of(2), MutableType.of(5));

		broadcastState2.put(MutableType.of(2), MutableType.of(5));

		BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);

		OneShotLatch waiterLatch = new OneShotLatch();
		OneShotLatch blockerLatch = new OneShotLatch();

		streamFactory.setWaiterLatch(waiterLatch);
		streamFactory.setBlockerLatch(blockerLatch);

		RunnableFuture<SnapshotResult<OperatorStateHandle>> runnableFuture =
				operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());

		ExecutorService executorService = Executors.newFixedThreadPool(1);

		try {
			executorService.submit(runnableFuture);

			// wait until the async checkpoint is in the write code, then continue
			waiterLatch.await();

			// do some mutations to the state, to test if our snapshot will NOT reflect them

			listState1.add(MutableType.of(77));

			broadcastState1.put(MutableType.of(32), MutableType.of(97));

			int n = 0;

			for (MutableType mutableType : listState2.get()) {
				if (++n == 2) {
					// allow the write code to continue, so that we could do changes while state is written in parallel.
					blockerLatch.trigger();
				}
				mutableType.setValue(mutableType.getValue() + 10);
			}

			listState3.clear();
			broadcastState2.clear();

			// a state registered after the snapshot was taken; the restored backend is expected
			// to know only the three list states that existed at snapshot time (asserted below)
			operatorStateBackend.getListState(
					new ListStateDescriptor<>("test4", new JavaSerializer<MutableType>()));

			// wait for the snapshot, bounded so that a stuck writer fails the test instead of hanging it
			SnapshotResult<OperatorStateHandle> snapshotResult = runnableFuture.get(60, TimeUnit.SECONDS);
			OperatorStateHandle stateHandle = snapshotResult.getJobManagerOwnedSnapshot();
			// fail with a clear assertion instead of a NullPointerException in the finally block below
			assertNotNull(stateHandle);

			try {

				operatorStateBackend.close();
				operatorStateBackend.dispose();

				AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096);
				CloseableRegistry cancelStreamRegistry = new CloseableRegistry();

				operatorStateBackend = abstractStateBackend.createOperatorStateBackend(
					createMockEnvironment(),
					"testOperator",
					StateObjectCollection.singleton(stateHandle),
					cancelStreamRegistry);

				assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
				assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size());

				listState1 = operatorStateBackend.getListState(stateDescriptor1);
				listState2 = operatorStateBackend.getListState(stateDescriptor2);
				listState3 = operatorStateBackend.getUnionListState(stateDescriptor3);

				broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1);
				broadcastState2 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor2);
				broadcastState3 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor3);

				assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
				assertEquals(3, operatorStateBackend.getRegisteredBroadcastStateNames().size());

				// none of the mutations done after the snapshot was triggered may show up
				assertAsyncRestoredValues(listState1, 42, 4711);
				assertAsyncRestoredValues(listState2, 7, 13, 23);
				assertAsyncRestoredValues(listState3, 17, 18, 19, 20);

				Iterator<Map.Entry<MutableType, MutableType>> bIt = broadcastState1.iterator();
				assertTrue(bIt.hasNext());
				Map.Entry<MutableType, MutableType> entry = bIt.next();
				assertEquals(1, entry.getKey().value);
				assertEquals(2, entry.getValue().value);
				assertTrue(bIt.hasNext());
				entry = bIt.next();
				assertEquals(2, entry.getKey().value);
				assertEquals(5, entry.getValue().value);
				assertFalse(bIt.hasNext());

				bIt = broadcastState2.iterator();
				assertTrue(bIt.hasNext());
				entry = bIt.next();
				assertEquals(2, entry.getKey().value);
				assertEquals(5, entry.getValue().value);
				assertFalse(bIt.hasNext());

				bIt = broadcastState3.iterator();
				assertFalse(bIt.hasNext());

				operatorStateBackend.close();
				operatorStateBackend.dispose();
			} finally {
				stateHandle.discardState();
			}
		} finally {
			// always release the worker thread, also when an assertion above has failed
			executorService.shutdownNow();
		}
	}

	/** Asserts that the list state yields elements with exactly the given values, in the given order. */
	private static void assertAsyncRestoredValues(ListState<MutableType> state, int... expectedValues) throws Exception {
		Iterator<MutableType> it = state.get().iterator();
		for (int expectedValue : expectedValues) {
			assertEquals(expectedValue, it.next().value);
		}
		assertFalse(it.hasNext());
	}

	/**
	 * Closes the backend while a snapshot is blocked in the stream's write code on another thread
	 * and expects that waiting for the snapshot future then ends with a
	 * {@link CancellationException}.
	 *
	 * <p>The executor running the snapshot is shut down in a {@code finally} block so that its
	 * worker thread does not outlive the test.
	 */
	@Test
	public void testSnapshotAsyncClose() throws Exception {
		DefaultOperatorStateBackend operatorStateBackend =
			new DefaultOperatorStateBackendBuilder(
				OperatorStateBackendTest.class.getClassLoader(),
				new ExecutionConfig(),
				true,
				emptyStateHandles,
				new CloseableRegistry()).build();

		ListStateDescriptor<MutableType> stateDescriptor1 =
				new ListStateDescriptor<>("test1", new JavaSerializer<MutableType>());

		ListState<MutableType> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1);

		listState1.add(MutableType.of(42));
		listState1.add(MutableType.of(4711));

		MapStateDescriptor<MutableType, MutableType> broadcastStateDescriptor1 =
				new MapStateDescriptor<>("test4", new JavaSerializer<MutableType>(), new JavaSerializer<MutableType>());

		BroadcastState<MutableType, MutableType> broadcastState1 = operatorStateBackend.getBroadcastState(broadcastStateDescriptor1);
		broadcastState1.put(MutableType.of(1), MutableType.of(2));
		broadcastState1.put(MutableType.of(2), MutableType.of(5));

		BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);

		OneShotLatch waiterLatch = new OneShotLatch();
		OneShotLatch blockerLatch = new OneShotLatch();

		streamFactory.setWaiterLatch(waiterLatch);
		streamFactory.setBlockerLatch(blockerLatch);

		RunnableFuture<SnapshotResult<OperatorStateHandle>> runnableFuture =
				operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());

		ExecutorService executorService = Executors.newFixedThreadPool(1);

		try {
			executorService.submit(runnableFuture);

			// wait until the async checkpoint is in the write code, then continue
			waiterLatch.await();

			// close the backend while the snapshot is still blocked in the stream
			operatorStateBackend.close();

			// let the blocked write continue
			blockerLatch.trigger();

			try {
				runnableFuture.get(60, TimeUnit.SECONDS);
				fail("Snapshot completed although the backend was closed while it was running");
			} catch (CancellationException expected) {
				// this is the outcome the test expects
			}
		} finally {
			// release the worker thread regardless of the test outcome
			executorService.shutdownNow();
		}
	}

	/**
	 * Cancels the snapshot future while the snapshot is blocked in the stream's write code on
	 * another thread, checks that all streams created by the factory are closed afterwards and
	 * expects that waiting for the future ends with a {@link CancellationException}.
	 *
	 * <p>The executor running the snapshot is shut down in a {@code finally} block so that its
	 * worker thread does not outlive the test.
	 */
	@Test
	public void testSnapshotAsyncCancel() throws Exception {
		DefaultOperatorStateBackend operatorStateBackend =
			new DefaultOperatorStateBackendBuilder(
				OperatorStateBackendTest.class.getClassLoader(),
				new ExecutionConfig(),
				true,
				emptyStateHandles,
				new CloseableRegistry()).build();

		ListStateDescriptor<MutableType> stateDescriptor1 =
				new ListStateDescriptor<>("test1", new JavaSerializer<MutableType>());

		ListState<MutableType> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1);

		listState1.add(MutableType.of(42));
		listState1.add(MutableType.of(4711));

		BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);

		OneShotLatch waiterLatch = new OneShotLatch();
		OneShotLatch blockerLatch = new OneShotLatch();

		streamFactory.setWaiterLatch(waiterLatch);
		streamFactory.setBlockerLatch(blockerLatch);

		RunnableFuture<SnapshotResult<OperatorStateHandle>> runnableFuture =
				operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());

		ExecutorService executorService = Executors.newFixedThreadPool(1);

		try {
			executorService.submit(runnableFuture);

			// wait until the async checkpoint is in the stream's write code, then continue
			waiterLatch.await();

			// cancel the future, which should close the underlying stream
			runnableFuture.cancel(true);

			for (BlockingCheckpointOutputStream stream : streamFactory.getAllCreatedStreams()) {
				assertTrue(stream.isClosed());
			}

			// we allow the stream under test to proceed
			blockerLatch.trigger();

			try {
				runnableFuture.get(60, TimeUnit.SECONDS);
				fail("Snapshot future returned a result although it was cancelled");
			} catch (CancellationException ignore) {
				// this is the outcome the test expects
			}
		} finally {
			// release the worker thread regardless of the test outcome
			executorService.shutdownNow();
		}
	}

	/**
	 * Serializable holder of a single mutable {@code int}. Equality and hash code are derived from
	 * the current value.
	 */
	static final class MutableType implements Serializable {

		private static final long serialVersionUID = 1L;

		/** The wrapped value; changeable through {@link #setValue(int)}. */
		private int value;

		/** Creates an instance holding {@code 0}. */
		public MutableType() {
			this.value = 0;
		}

		/** Creates an instance holding the given value. */
		public MutableType(int value) {
			this.value = value;
		}

		/** Factory shorthand for {@code new MutableType(value)}. */
		static MutableType of(int value) {
			return new MutableType(value);
		}

		public int getValue() {
			return this.value;
		}

		public void setValue(int value) {
			this.value = value;
		}

		@Override
		public boolean equals(Object o) {
			if (o == this) {
				return true;
			}
			// the class is final, so an instanceof test matches exactly the objects of the same class
			// (and rejects null)
			if (!(o instanceof MutableType)) {
				return false;
			}
			final MutableType other = (MutableType) o;
			return other.value == this.value;
		}

		@Override
		public int hashCode() {
			// Integer.hashCode(int) returns the int itself
			return Integer.hashCode(value);
		}
	}

	// ------------------------------------------------------------------------
	//  utilities
	// ------------------------------------------------------------------------

	/**
	 * Builds a Mockito mock of {@link Environment} that returns a new {@link ExecutionConfig} as
	 * execution config and the class loader of this test class as user class loader.
	 */
	private static Environment createMockEnvironment() {
		final ExecutionConfig executionConfig = new ExecutionConfig();
		final ClassLoader testClassLoader = OperatorStateBackendTest.class.getClassLoader();

		final Environment mockedEnvironment = mock(Environment.class);
		when(mockedEnvironment.getExecutionConfig()).thenReturn(executionConfig);
		when(mockedEnvironment.getUserClassLoader()).thenReturn(testClassLoader);
		return mockedEnvironment;
	}

	/**
	 * Closes and disposes the given backend, then creates a new operator state backend from the
	 * given state handles.
	 *
	 * @param oldOperatorStateBackend backend to close and dispose before the new one is created
	 * @param abstractStateBackend state backend used to create the new operator state backend
	 * @param toRestore state handles handed to the new backend
	 * @return the newly created backend
	 * @throws Exception if closing the old backend or creating the new one fails
	 */
	private static OperatorStateBackend recreateOperatorStateBackend(
		OperatorStateBackend oldOperatorStateBackend,
		AbstractStateBackend abstractStateBackend,
		Collection<OperatorStateHandle> toRestore
	) throws Exception {
		// shut down the previous backend first
		oldOperatorStateBackend.close();
		oldOperatorStateBackend.dispose();

		final Environment environment = createMockEnvironment();
		final CloseableRegistry registry = new CloseableRegistry();
		final OperatorStateBackend recreated = abstractStateBackend.createOperatorStateBackend(
			environment,
			"testOperator",
			toRestore,
			registry);
		return recreated;
	}
}