/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.king.bravo.writer;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.function.BiConsumer;

import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.savepoint.Savepoint;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.CheckpointedStateScope;
import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.filesystem.FileBasedStateOutputStream;
import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
import org.apache.flink.shaded.guava18.com.google.common.collect.Maps;

import com.google.common.collect.HashBiMap;
import com.google.common.collect.Lists;
import com.king.bravo.reader.OperatorStateReader;
import com.king.bravo.types.KeyedStateRow;
import com.king.bravo.utils.StateMetadataUtils;
import com.king.bravo.writer.functions.KeyGroupAndStateNameKey;
import com.king.bravo.writer.functions.OperatorIndexForKeyGroupKey;
import com.king.bravo.writer.functions.RocksDBSavepointWriter;
import com.king.bravo.writer.functions.RowFilter;
import com.king.bravo.writer.functions.ValueStateToKeyedStateRow;

/**
 * Utility for creating new OperatorStates based on old checkpointed data and
 * new state from datasets. This can be used to modify add and remove keyed
 * states as well as to modify non-keyed states.
 */
public class OperatorStateWriter {

	private OperatorState baseOpState;
	private DataSet<KeyedStateRow> allRows = null;
	private Path newCheckpointBasePath;
	private long checkpointId;
	private BiConsumer<Integer, OperatorStateBackend> transformer;
	private Map<String, StateMetaInfoSnapshot> metaSnapshots;
	private KeyedBackendSerializationProxy<?> proxy;

	private boolean keepBaseKeyedStates = true;

	private TypeSerializer<?> keySerializer = null;

	public OperatorStateWriter(Savepoint sp, String uid, Path newCheckpointBasePath) {
		this(sp.getCheckpointId(), StateMetadataUtils.getOperatorState(sp, uid), newCheckpointBasePath);
	}

	public OperatorStateWriter(long checkpointId, OperatorState baseOpState, Path newCheckpointBasePath) {
		this.baseOpState = baseOpState;
		this.newCheckpointBasePath = newCheckpointBasePath;
		this.checkpointId = checkpointId;

		proxy = StateMetadataUtils.getKeyedBackendSerializationProxy(baseOpState).orElse(null);
		metaSnapshots = new HashMap<>();
		if (proxy != null) {
			proxy.getStateMetaInfoSnapshots()
					.forEach(ms -> metaSnapshots.put(ms.getName(),
							new StateMetaInfoSnapshot(ms.getName(), ms.getBackendStateType(), ms.getOptionsImmutable(),
									ms.getSerializerSnapshotsImmutable(),
									Maps.transformValues(ms.getSerializerSnapshotsImmutable(),
											TypeSerializerSnapshot::restoreSerializer))));
		}
	}

	/**
	 * Defines the keyserializer for this operator. This method can be used when
	 * adding state to a previously stateless operator where the keyserializer
	 * is not available from the state.
	 * 
	 * @param keySerializer
	 */
	public void setKeySerializer(TypeSerializer<?> keySerializer) {
		this.keySerializer = keySerializer;
	}

	/**
	 * Add a Dataset of {@link KeyedStateRow}s to the state of the operator,
	 * this is mostly used to migrate existing states of the operator to the new
	 * operator state without modifications.
	 * <p>
	 * This can be used to add all different kinds of keyed states: value, list,
	 * map
	 * 
	 * @param rows
	 *            State rows to be added
	 */
	public void addKeyedStateRows(DataSet<KeyedStateRow> rows) {
		allRows = allRows == null ? rows : allRows.union(rows);
		keepBaseKeyedStates = false;
	}

	/**
	 * Removes the state metadata and rows for the given statename.
	 * 
	 * @param stateName
	 *            Name of the state to be deleted
	 */
	public void deleteKeyedState(String stateName) {
		metaSnapshots.remove(stateName);
		keepBaseKeyedStates = false;
	}

	private void updateProxy() {
		if (proxy == null && keySerializer == null) {
			throw new IllegalStateException(
					"KeySerializer must be defined when adding state to a previously stateless operator. Use writer.setKeySerializer(...)");
		}

		proxy = new KeyedBackendSerializationProxy<>(
				getKeySerializer(),
				new ArrayList<>(metaSnapshots.values()),
				proxy != null ? proxy.isUsingKeyGroupCompression() : true);
	}

	@SuppressWarnings({ "rawtypes", "unchecked" })
	private <T> TypeSerializer<T> getKeySerializer() {
		return proxy != null ? (TypeSerializer) proxy.getKeySerializerConfigSnapshot().restoreSerializer()
				: (TypeSerializer) keySerializer;
	}

	/**
	 * Adds a dataset of K-V pairs to the keyed state of the operator. This
	 * operation assumes that a state with the same name is already defined and
	 * the metadata is reused.
	 * <p>
	 * To define new states see
	 * {@link #createNewValueState(String, DataSet, TypeSerializer)}
	 * <p>
	 * Keep in mind that any state rows for the same name already added (through
	 * {@link #addKeyedStateRows(DataSet)}) will not be overwritten.
	 * 
	 * @param stateName
	 * @param newState
	 */
	@SuppressWarnings({ "rawtypes", "unchecked" })
	public <K, V> void addValueState(String stateName, DataSet<Tuple2<K, V>> newState) {
		TypeSerializer<V> valueSerializer = (TypeSerializer<V>) (TypeSerializer) StateMetadataUtils
				.getSerializer(proxy, stateName)
				.orElseThrow(
						() -> new IllegalArgumentException("Cannot find state " + stateName));

		if (StateMetadataUtils.isTtlState(valueSerializer)) {
			throw new RuntimeException("Writing of TTL states is not supported at the moment.");
		}
		addKeyedStateRows(newState
				.map(new ValueStateToKeyedStateRow<K, V>(stateName,
						getKeySerializer(),
						valueSerializer,
						baseOpState.getMaxParallelism())));
	}

	/**
	 * Defines/redefines a value state with the given name and type. This can be
	 * used to create new states of an operator or change the type of an already
	 * existing state.
	 * <p>
	 * When redefining a pre-existing state make sure you haven't added that as
	 * keyed state rows before.
	 * 
	 * @param stateName
	 * @param newState
	 * @param valueSerializer
	 */
	public <K, V> void createNewValueState(String stateName, DataSet<Tuple2<K, V>> newState,
			TypeSerializer<V> valueSerializer) {

		metaSnapshots.put(stateName, new RegisteredKeyValueStateBackendMetaInfo<>(StateDescriptor.Type.VALUE, stateName,
				VoidNamespaceSerializer.INSTANCE, valueSerializer).snapshot());

		updateProxy();
		addKeyedStateRows(newState
				.map(new ValueStateToKeyedStateRow<K, V>(stateName,
						getKeySerializer(),
						valueSerializer,
						baseOpState.getMaxParallelism())));
	}

	/**
	 * Triggers the batch processing operations to write the operator state data
	 * to persistent storage and create the metadata object
	 * 
	 * @return {@link OperatorState} metadata pointing to the newly written
	 *         state
	 */
	public OperatorState writeAll() throws Exception {
		int maxParallelism = baseOpState.getMaxParallelism();
		int parallelism = baseOpState.getParallelism();
		Path outDir = makeOutputDir();

		Map<Integer, KeyedStateHandle> handleMap = new HashMap<>();
		if (allRows == null) {
			if (!keepBaseKeyedStates && !metaSnapshots.isEmpty()) {
				throw new IllegalStateException(
						"States must be added when any modification of existing keyed states were made");
			} else {
				// Either keep all state or delete all (no need to do anything
				// here)
			}
		} else if (!metaSnapshots.isEmpty()) {
			updateProxy();
			ByteArrayOutputStream os = new ByteArrayOutputStream();
			DataOutputView bow = new DataOutputViewStreamWrapper(os);
			proxy.write(bow);

			DataSet<Tuple2<Integer, KeyedStateHandle>> handles = allRows
					.filter(new RowFilter(metaSnapshots.keySet()))
					.groupBy(new OperatorIndexForKeyGroupKey(maxParallelism, parallelism))
					.sortGroup(new KeyGroupAndStateNameKey(maxParallelism), Order.ASCENDING)
					.reduceGroup(new RocksDBSavepointWriter(maxParallelism, parallelism,
							HashBiMap.create(StateMetadataUtils.getStateIdMapping(proxy)).inverse(),
							proxy.isUsingKeyGroupCompression(), outDir, os.toByteArray()));

			handles.collect().stream().forEach(t -> handleMap.put(t.f0, t.f1));
		} else {
			throw new IllegalStateException(
					"There are state rows but no state metadata... maybe you meant to use createNewValueState(...)");
		}

		// We construct a new operatorstate with the collected handles
		OperatorState newOperatorState = new OperatorState(baseOpState.getOperatorID(), parallelism, maxParallelism);

		// Fill with the subtaskstates based on the old one (we need to preserve
		// the
		// other states)
		baseOpState.getSubtaskStates().forEach((subtaskId, subtaskState) -> {
			KeyedStateHandle newKeyedHandle = handleMap.get(subtaskId);
			StateObjectCollection<OperatorStateHandle> opHandle = transformSubtaskOpState(outDir, subtaskId,
					subtaskState.getManagedOperatorState());

			newOperatorState.putState(subtaskId,
					new OperatorSubtaskState(
							opHandle,
							subtaskState.getRawOperatorState(),
							keepBaseKeyedStates ? subtaskState.getManagedKeyedState()
									: new StateObjectCollection<>(
											newKeyedHandle != null
													? Lists.newArrayList(newKeyedHandle)
													: Collections.emptyList()),
							subtaskState.getRawKeyedState()));
		});

		return newOperatorState;
	}

	private StateObjectCollection<OperatorStateHandle> transformSubtaskOpState(Path outDir, Integer subtaskId,
			StateObjectCollection<OperatorStateHandle> baseState) {

		if (transformer == null) {
			return baseState;
		}

		StateObjectCollection<OperatorStateHandle> opHandle = baseState;
		try (OperatorStateBackend opBackend = OperatorStateReader
				.restoreOperatorStateBackend(opHandle)) {

			transformer.accept(subtaskId, opBackend);

			OperatorStateHandle newSnapshot = opBackend
					.snapshot(checkpointId, System.currentTimeMillis(), new CheckpointStreamFactory() {
						@Override
						public CheckpointStateOutputStream createCheckpointStateOutputStream(
								CheckpointedStateScope scope)
								throws IOException {
							return new FileBasedStateOutputStream(outDir.getFileSystem(),
									new Path(outDir, String.valueOf(UUID.randomUUID())));
						}
					}, null).get().getJobManagerOwnedSnapshot();
			return new StateObjectCollection<>(Lists.newArrayList(newSnapshot));
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Transform the non-keyed state of the operator by applying a function to
	 * the non-keyed state of each operator instance. Any update made to the
	 * {@link OperatorStateBackend} will be stored back as the new state of the
	 * operator.
	 * <p>
	 * The transformation will be executed sequentially, in-memory on the
	 * client.
	 * 
	 * @param transformer
	 *            Consumer to be applied on the {@link OperatorStateBackend}
	 * @throws Exception
	 */
	public <K, V> void transformNonKeyedState(BiConsumer<Integer, OperatorStateBackend> transformer) throws Exception {
		this.transformer = transformer;
	}

	private Path makeOutputDir() {
		final Path outDir = new Path(new Path(newCheckpointBasePath, "mchk-" + checkpointId),
				"op-" + baseOpState.getOperatorID());
		try {
			outDir.getFileSystem().mkdirs(outDir);
		} catch (IOException ignore) {}
		return outDir;
	}

}