/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.typeserializerupgrade;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.testutils.CommonTestUtils;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.StateBackendLoader;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.util.DynamicCodeLoadingException;
import org.apache.flink.util.IOUtils;
import org.apache.flink.util.StateMigrationException;
import org.apache.flink.util.TestLogger;

import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Tests the state migration behaviour when the underlying POJO type changes
 * and one tries to recover from old state.
 *
 * <p>Each test compiles a first version of a {@code Pojo} class at runtime and runs a stateful
 * operator that writes instances of it into keyed or operator state. It then takes a snapshot and
 * restores that snapshot into an operator whose user code class loader sees a second, modified
 * version of {@code Pojo}. The test is parameterized over the configured state backends.
 */
@RunWith(Parameterized.class)
public class PojoSerializerUpgradeTest extends TestLogger {

	@Parameterized.Parameters(name = "StateBackend: {0}")
	public static Collection<String> parameters() {
		return Arrays.asList(
				StateBackendLoader.MEMORY_STATE_BACKEND_NAME,
				StateBackendLoader.FS_STATE_BACKEND_NAME,
				StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME);
	}

	@ClassRule
	public static final TemporaryFolder temporaryFolder = new TemporaryFolder();

	private final StateBackend stateBackend;

	public PojoSerializerUpgradeTest(String backendType) throws IOException, DynamicCodeLoadingException {
		Configuration config = new Configuration();
		config.setString(CheckpointingOptions.STATE_BACKEND, backendType);
		config.setString(CheckpointingOptions.CHECKPOINTS_DIRECTORY, temporaryFolder.newFolder().toURI().toString());
		stateBackend = StateBackendLoader.loadStateBackendFromConfig(config, Thread.currentThread().getContextClassLoader(), null);
	}

	private static final String POJO_NAME = "Pojo";

	// original POJO: a long field "a" and a String field "b"
	private static final String SOURCE_A =
		"import java.util.Objects;" +
		"public class Pojo { " +
		"private long a; " +
		"private String b; " +
		"public long getA() { return a;} " +
		"public void setA(long value) { a = value; }" +
		"public String getB() { return b; }" +
		"public void setB(String value) { b = value; }" +
		"@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}" +
		"@Override public int hashCode() { return Objects.hash(a, b); } " +
		"@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

	// changed order of fields, which should be recoverable
	private static final String SOURCE_B =
		"import java.util.Objects;" +
		"public class Pojo { " +
		"private String b; " +
		"private long a; " +
		"public long getA() { return a;} " +
		"public void setA(long value) { a = value; }" +
		"public String getB() { return b; }" +
		"public void setB(String value) { b = value; }" +
		"@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}" +
		"@Override public int hashCode() { return Objects.hash(a, b); } " +
		"@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

	// changed type of field "a" (long -> double), which should not be recoverable
	// and is expected to fail with a StateMigrationException
	private static final String SOURCE_C =
		"import java.util.Objects;" +
		"public class Pojo { " +
		"private double a; " +
		"private String b; " +
		"public double getA() { return a;} " +
		"public void setA(double value) { a = value; }" +
		"public String getB() { return b; }" +
		"public void setB(String value) { b = value; }" +
		"@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}" +
		"@Override public int hashCode() { return Objects.hash(a, b); } " +
		"@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

	// additional field "c", which should be recoverable (see testAdditionalField* tests)
	private static final String SOURCE_D =
		"import java.util.Objects;" +
		"public class Pojo { " +
		"private long a; " +
		"private String b; " +
		"private double c; " +
		"public long getA() { return a;} " +
		"public void setA(long value) { a = value; }" +
		"public String getB() { return b; }" +
		"public void setB(String value) { b = value; }" +
		"public double getC() { return c; } " +
		"public void setC(double value) { c = value; }" +
		"@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b) && c == other.c;} else { return false; }}" +
		"@Override public int hashCode() { return Objects.hash(a, b, c); } " +
		"@Override public String toString() {return \"(\" + a + \", \" + b + \", \" + c + \")\";}}";

	// missing field "b", which should be recoverable (see testMissingField* tests)
	private static final String SOURCE_E =
		"import java.util.Objects;" +
		"public class Pojo { " +
		"private long a; " +
		"public long getA() { return a;} " +
		"public void setA(long value) { a = value; }" +
		"@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a;} else { return false; }}" +
		"@Override public int hashCode() { return Objects.hash(a); } " +
		"@Override public String toString() {return \"(\" + a + \")\";}}";

	/**
	 * We should be able to handle a changed field order of a POJO as keyed state.
	 */
	@Test
	public void testChangedFieldOrderWithKeyedState() throws Exception {
		testPojoSerializerUpgrade(SOURCE_A, SOURCE_B, true, true);
	}

	/**
	 * We should be able to handle a changed field order of a POJO as operator state.
	 */
	@Test
	public void testChangedFieldOrderWithOperatorState() throws Exception {
		testPojoSerializerUpgrade(SOURCE_A, SOURCE_B, true, false);
	}

	/**
	 * Changing field types of a POJO as keyed state should require a state migration.
	 */
	@Test
	public void testChangedFieldTypesWithKeyedState() throws Exception {
		assertPojoSerializerUpgradeRequiresMigration(SOURCE_A, SOURCE_C, true, true);
	}

	/**
	 * Changing field types of a POJO as operator state should require a state migration.
	 */
	@Test
	public void testChangedFieldTypesWithOperatorState() throws Exception {
		assertPojoSerializerUpgradeRequiresMigration(SOURCE_A, SOURCE_C, true, false);
	}

	/**
	 * Adding fields to a POJO as keyed state should succeed.
	 */
	@Test
	public void testAdditionalFieldWithKeyedState() throws Exception {
		testPojoSerializerUpgrade(SOURCE_A, SOURCE_D, true, true);
	}

	/**
	 * Adding fields to a POJO as operator state should succeed.
	 */
	@Test
	public void testAdditionalFieldWithOperatorState() throws Exception {
		testPojoSerializerUpgrade(SOURCE_A, SOURCE_D, true, false);
	}

	/**
	 * Removing fields from a POJO as keyed state should succeed.
	 */
	@Test
	public void testMissingFieldWithKeyedState() throws Exception {
		testPojoSerializerUpgrade(SOURCE_A, SOURCE_E, false, true);
	}

	/**
	 * Removing fields from a POJO as operator state should succeed.
	 */
	@Test
	public void testMissingFieldWithOperatorState() throws Exception {
		testPojoSerializerUpgrade(SOURCE_A, SOURCE_E, false, false);
	}

	/**
	 * Runs {@link #testPojoSerializerUpgrade} and asserts that it fails with an exception whose
	 * cause chain contains a {@link StateMigrationException}. Any other exception is rethrown.
	 */
	private void assertPojoSerializerUpgradeRequiresMigration(
			String classSourceA,
			String classSourceB,
			boolean hasBField,
			boolean isKeyedState) throws Exception {
		try {
			testPojoSerializerUpgrade(classSourceA, classSourceB, hasBField, isKeyedState);
			// fail() throws an AssertionError, which is not caught by the Exception handler below
			fail("Expected a state migration exception.");
		} catch (Exception e) {
			if (!CommonTestUtils.containsCause(e, StateMigrationException.class)) {
				throw e;
			}
			// StateMigrationException expected
		}
	}

	/**
	 * Writes state with a {@code Pojo} compiled from {@code classSourceA}. It then restores that
	 * state into an operator whose class loader sees {@code Pojo} compiled from {@code classSourceB},
	 * and verifies that the restored values match.
	 *
	 * @param classSourceA source of the POJO used to write the state
	 * @param classSourceB source of the POJO used to restore and verify the state
	 * @param hasBField whether the mapper should set and compare field {@code b}
	 * @param isKeyedState whether keyed state (true) or operator state (false) is exercised
	 */
	private void testPojoSerializerUpgrade(String classSourceA, String classSourceB, boolean hasBField, boolean isKeyedState) throws Exception {
		final Configuration taskConfiguration = new Configuration();
		final ExecutionConfig executionConfig = new ExecutionConfig();
		final KeySelector<Long, Long> keySelector = new IdentityKeySelector<>();
		final Collection<Long> inputs = Arrays.asList(1L, 2L, 45L, 67L, 1337L);

		// run the program with classSourceA
		File rootPath = temporaryFolder.newFolder();
		File sourceFile = writeSourceFile(rootPath, POJO_NAME + ".java", classSourceA);
		compileClass(sourceFile);

		final ClassLoader classLoader = URLClassLoader.newInstance(
			new URL[]{rootPath.toURI().toURL()},
			Thread.currentThread().getContextClassLoader());

		OperatorSubtaskState stateHandles = runOperator(
			taskConfiguration,
			executionConfig,
			new StreamMap<>(new StatefulMapper(isKeyedState, false, hasBField)),
			keySelector,
			isKeyedState,
			stateBackend,
			classLoader,
			null,
			inputs);

		// run the program with classSourceB, restoring from the snapshot taken above
		rootPath = temporaryFolder.newFolder();

		sourceFile = writeSourceFile(rootPath, POJO_NAME + ".java", classSourceB);
		compileClass(sourceFile);

		final ClassLoader classLoaderB = URLClassLoader.newInstance(
			new URL[]{rootPath.toURI().toURL()},
			Thread.currentThread().getContextClassLoader());

		runOperator(
			taskConfiguration,
			executionConfig,
			new StreamMap<>(new StatefulMapper(isKeyedState, true, hasBField)),
			keySelector,
			isKeyedState,
			stateBackend,
			classLoaderB,
			stateHandles,
			inputs);
	}

	/**
	 * Runs {@code operator} in a test harness that uses the given user code class loader. It
	 * optionally restores from {@code operatorSubtaskState}, feeds all {@code input} elements,
	 * and returns a snapshot of the resulting state.
	 *
	 * @param operatorSubtaskState state to restore from, or {@code null} for a fresh start
	 * @return the snapshot taken after all input elements were processed
	 */
	private OperatorSubtaskState runOperator(
			Configuration taskConfiguration,
			ExecutionConfig executionConfig,
			OneInputStreamOperator<Long, Long> operator,
			KeySelector<Long, Long> keySelector,
			boolean isKeyedState,
			StateBackend stateBackend,
			ClassLoader classLoader,
			OperatorSubtaskState operatorSubtaskState,
			Iterable<Long> input) throws Exception {

		try (final MockEnvironment environment =
				new MockEnvironmentBuilder()
					.setTaskName("test task")
					.setMemorySize(32 * 1024)
					.setInputSplitProvider(new MockInputSplitProvider())
					.setBufferSize(256)
					.setTaskConfiguration(taskConfiguration)
					.setExecutionConfig(executionConfig)
					.setMaxParallelism(16)
					.setUserCodeClassLoader(classLoader)
					.build()) {

			OneInputStreamOperatorTestHarness<Long, Long> harness = null;
			try {
				if (isKeyedState) {
					harness = new KeyedOneInputStreamOperatorTestHarness<>(
						operator,
						keySelector,
						BasicTypeInfo.LONG_TYPE_INFO,
						environment);
				} else {
					harness = new OneInputStreamOperatorTestHarness<>(operator, LongSerializer.INSTANCE, environment);
				}

				harness.setStateBackend(stateBackend);

				harness.setup();
				harness.initializeState(operatorSubtaskState);
				harness.open();

				long timestamp = 0L;

				for (Long value : input) {
					harness.processElement(value, timestamp++);
				}

				long checkpointId = 1L;
				long checkpointTimestamp = timestamp + 1L;

				return harness.snapshot(checkpointId, checkpointTimestamp);
			} finally {
				IOUtils.closeQuietly(harness);
			}
		}
	}

	/**
	 * Writes {@code source} to the file {@code name} under {@code root} and returns that file.
	 */
	private static File writeSourceFile(File root, String name, String source) throws IOException {
		File sourceFile = new File(root, name);

		sourceFile.getParentFile().mkdirs();

		try (FileWriter writer = new FileWriter(sourceFile)) {
			writer.write(source);
		}

		return sourceFile;
	}

	/**
	 * Compiles {@code sourceFile} with the system Java compiler. No {@code -d} option is passed, so
	 * the class file is placed next to the source file, where the test's class loader expects it.
	 *
	 * @return the compiler exit code, which is always 0 because failures throw
	 * @throws IllegalStateException if no system compiler is available (for example, when running
	 *     on a JRE) or if compilation fails
	 */
	private static int compileClass(File sourceFile) {
		JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
		if (compiler == null) {
			throw new IllegalStateException(
				"No system Java compiler available; this test must run on a JDK, not a JRE.");
		}

		int result = compiler.run(null, null, null, "-proc:none", sourceFile.getPath());
		if (result != 0) {
			// fail fast here instead of surfacing later as a confusing ClassNotFoundException
			throw new IllegalStateException(
				"Compilation of " + sourceFile + " failed with exit code " + result + '.');
		}
		return result;
	}

	/**
	 * Mapper that, depending on {@code verify}, either writes a {@code Pojo} built from each
	 * input value into keyed or operator state, or asserts that the restored state contains an
	 * equal {@code Pojo}. The {@code Pojo} class is loaded reflectively from the user code class
	 * loader.
	 */
	private static final class StatefulMapper extends RichMapFunction<Long, Long> implements CheckpointedFunction {

		private static final long serialVersionUID = -520490739059396832L;

		private final boolean keyed;
		private final boolean verify;
		private final boolean hasBField;

		// keyed states
		private transient ValueState<Object> keyedValueState;
		private transient ListState<Object> keyedListState;
		private transient ReducingState<Object> keyedReducingState;

		// operator states
		private transient ListState<Object> partitionableListState;
		private transient ListState<Object> unionListState;

		private transient Class<?> pojoClass;
		private transient Field fieldA;
		private transient Field fieldB;

		StatefulMapper(boolean keyed, boolean verify, boolean hasBField) {
			this.keyed = keyed;
			this.verify = verify;
			this.hasBField = hasBField;
		}

		@Override
		public Long map(Long value) throws Exception {
			Object pojo = pojoClass.getDeclaredConstructor().newInstance();

			fieldA.set(pojo, value);

			if (hasBField) {
				fieldB.set(pojo, value + "");
			}

			if (verify) {
				if (keyed) {
					assertEquals(pojo, keyedValueState.value());
					assertTrue(
						"Expected " + pojo + " in the keyed list state.",
						containsElement(keyedListState.get(), pojo));
					assertEquals(pojo, keyedReducingState.get());
				} else {
					assertTrue(
						"Expected " + pojo + " in the partitionable list state.",
						containsElement(partitionableListState.get(), pojo));
					assertTrue(
						"Expected " + pojo + " in the union list state.",
						containsElement(unionListState.get(), pojo));
				}
			} else {
				if (keyed) {
					keyedValueState.update(pojo);
					keyedListState.add(pojo);
					keyedReducingState.add(pojo);
				} else {
					partitionableListState.add(pojo);
					unionListState.add(pojo);
				}
			}

			return value;
		}

		/**
		 * Returns whether {@code elements} contains an element equal to {@code expected}.
		 * A {@code null} iterable is treated as empty.
		 */
		private static boolean containsElement(Iterable<Object> elements, Object expected) {
			if (elements == null) {
				return false;
			}
			Iterator<Object> iterator = elements.iterator();
			while (iterator.hasNext()) {
				if (expected.equals(iterator.next())) {
					return true;
				}
			}
			return false;
		}

		@Override
		public void snapshotState(FunctionSnapshotContext context) throws Exception {
			// nothing to do: all state lives in the state objects registered in initializeState
		}

		@SuppressWarnings("unchecked")
		@Override
		public void initializeState(FunctionInitializationContext context) throws Exception {
			pojoClass = getRuntimeContext().getUserCodeClassLoader().loadClass(POJO_NAME);

			fieldA = pojoClass.getDeclaredField("a");
			fieldA.setAccessible(true);

			if (hasBField) {
				fieldB = pojoClass.getDeclaredField("b");
				fieldB.setAccessible(true);
			}

			if (keyed) {
				keyedValueState = context.getKeyedStateStore().getState(
					new ValueStateDescriptor<>("keyedValueState", (Class<Object>) pojoClass));
				keyedListState = context.getKeyedStateStore().getListState(
					new ListStateDescriptor<>("keyedListState", (Class<Object>) pojoClass));

				ReduceFunction<Object> reduceFunction = new FirstValueReducer<>();
				keyedReducingState = context.getKeyedStateStore().getReducingState(
					new ReducingStateDescriptor<>("keyedReducingState", reduceFunction, (Class<Object>) pojoClass));
			} else {
				partitionableListState = context.getOperatorStateStore().getListState(
					new ListStateDescriptor<>("partitionableListState", (Class<Object>) pojoClass));
				unionListState = context.getOperatorStateStore().getUnionListState(
					new ListStateDescriptor<>("unionListState", (Class<Object>) pojoClass));
			}
		}
	}

	/**
	 * Reduce function that always keeps the first of the two values it is given.
	 */
	private static final class FirstValueReducer<T> implements ReduceFunction<T> {

		private static final long serialVersionUID = -9222976423336835926L;

		@Override
		public T reduce(T value1, T value2) throws Exception {
			return value1;
		}
	}

	/**
	 * Key selector that uses each element as its own key.
	 */
	private static final class IdentityKeySelector<T> implements KeySelector<T, T> {

		private static final long serialVersionUID = -3263628393881929147L;

		@Override
		public T getKey(T value) throws Exception {
			return value;
		}
	}
}