/*
 *
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * 	http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package com.alibaba.flink.connectors.common.source;

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.PojoField;
import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.io.InputSplit;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Parallel source whose input splits are re-assigned to subtasks when the job is restored from
 * checkpointed state.
 *
 * <p>Each subtask snapshots the progress of its splits into a union list state. On restore the
 * entries read back from that state are handed to
 * {@link #reAssignInputSplitsForCurrentSubTask(int, int, List)} and
 * {@link #reserveInputSplitsForCurrentSubTask(int, int, List)} so that every subtask can select
 * the splits it is responsible for.
 *
 * @param <T> first type argument, passed through unchanged to {@link AbstractParallelSourceBase}
 * @param <CURSOR> serializable per-split read position that is stored in the checkpointed state
 */
public abstract class AbstractDynamicParallelSource<T, CURSOR extends Serializable>
		extends AbstractParallelSourceBase<T, CURSOR>
		implements CheckpointedFunction {
	private static final Logger LOG = LoggerFactory.getLogger(AbstractDynamicParallelSource.class);
	private static final long serialVersionUID = -7848357196819780804L;
	private static final String SOURCE_STATE_NAME = "source_offsets_state_name";

	/** Position of {@code CURSOR} in this class' type parameter list {@code <T, CURSOR>}. */
	private static final int CURSOR_TYPE_PARAMETER_INDEX = 1;

	/** Union list state holding one entry per split; written in {@link #snapshotState}. */
	private transient ListState<InnerProgress<CURSOR>> unionInitialProgress;

	/** Copy of the entries read from the union state in {@link #initializeState}. */
	private transient List<InnerProgress<CURSOR>> allSplitsInCP;

	/** Splits kept in state on behalf of this subtask; only populated after a restore. */
	protected transient List<Tuple2<InputSplit, CURSOR>> reservedProgress;

	public AbstractDynamicParallelSource() {
		super();
	}

	/**
	 * Selects the splits (and their start cursors) this subtask should read after a restore.
	 * The result becomes the initial progress of this subtask.
	 *
	 * @param numberOfParallelSubTasks current parallelism of the source
	 * @param indexOfThisSubTask index of the calling subtask
	 * @param allSplitsInState all progress entries read back from the union state
	 * @return the splits assigned to this subtask together with their cursors
	 * @throws IOException if the assignment cannot be computed
	 */
	public abstract List<Tuple2<InputSplit, CURSOR>> reAssignInputSplitsForCurrentSubTask(
			int numberOfParallelSubTasks,
			int indexOfThisSubTask,
			List<InnerProgress<CURSOR>> allSplitsInState) throws IOException;

	/**
	 * Selects restored splits whose progress must stay in the state even though they are not
	 * part of the result of {@link #reAssignInputSplitsForCurrentSubTask(int, int, List)}.
	 * The default implementation reserves nothing.
	 *
	 * @param numberOfParallelSubTasks current parallelism of the source
	 * @param indexOfThisSubTask index of the calling subtask
	 * @param allSplitsInState all progress entries read back from the union state
	 * @return the splits to keep in state; an empty, mutable list by default
	 * @throws IOException if the reservation cannot be computed
	 */
	public List<Tuple2<InputSplit, CURSOR>> reserveInputSplitsForCurrentSubTask(
			int numberOfParallelSubTasks,
			int indexOfThisSubTask,
			List<InnerProgress<CURSOR>> allSplitsInState) throws IOException {
		return new ArrayList<>();
	}

	/**
	 * When the source was restored from state, computes the assigned and reserved splits of
	 * this subtask from the restored entries, then delegates to the superclass.
	 *
	 * @param config configuration forwarded to the superclass implementation
	 * @throws IOException if split assignment or the superclass call fails
	 */
	@Override
	protected void createParallelReader(Configuration config) throws IOException {
		if (isRecoryFromState()) {
			LOG.info("Reocory State!");
			initialProgress = reAssignInputSplitsForCurrentSubTask(
					getRuntimeContext().getNumberOfParallelSubtasks(),
					getRuntimeContext().getIndexOfThisSubtask(),
					allSplitsInCP);
			reservedProgress = reserveInputSplitsForCurrentSubTask(
					getRuntimeContext().getNumberOfParallelSubtasks(),
					getRuntimeContext().getIndexOfThisSubtask(),
					allSplitsInCP);
		}
		super.createParallelReader(config);
	}

	/**
	 * Registers the union list state and, if the job is restored, copies every restored entry
	 * into {@code allSplitsInCP} and marks the source as recovered.
	 *
	 * @param context initialization context giving access to the operator state store
	 * @throws Exception if the state descriptor cannot be built or the state cannot be read
	 */
	@Override
	public void initializeState(FunctionInitializationContext context) throws Exception {
		LOG.info("initializeState");
		unionInitialProgress =
				context.getOperatorStateStore().getUnionListState(createProgressStateDescriptor());
		LOG.info("Restoring state: {}", unionInitialProgress);
		allSplitsInCP = new ArrayList<>();
		if (context.isRestored()) {
			recoryFromState = true;
			for (InnerProgress<CURSOR> restored : unionInitialProgress.get()) {
				allSplitsInCP.add(new InnerProgress<>(restored.inputSplit, restored.cursor));
			}
		}
	}

	/**
	 * Writes the progress of this subtask into the union list state: first the cursors reported
	 * by the reader, then the initial and reserved splits the reader has no progress for.
	 * Each split is written at most once; the first entry encountered wins.
	 *
	 * @param context snapshot context (unused)
	 * @throws Exception if the state cannot be updated
	 */
	@Override
	public void snapshotState(FunctionSnapshotContext context) throws Exception {
		if (disableParallelRead) {
			return;
		}
		unionInitialProgress.clear();

		// Splits that already have an entry in this snapshot.
		Set<InputSplit> snapshottedSplits = new HashSet<>();
		for (Map.Entry<InputSplit, CURSOR> entry : parallelReader.getProgress().getProgress().entrySet()) {
			unionInitialProgress.add(new InnerProgress<>(entry.getKey(), entry.getValue()));
			snapshottedSplits.add(entry.getKey());
		}

		// Splits assigned to this subtask that the reader has not reported progress for.
		addMissingProgress(initialProgress, snapshottedSplits);

		// Splits only kept in state; absent unless the source was restored.
		if (null != reservedProgress) {
			addMissingProgress(reservedProgress, snapshottedSplits);
		}
	}

	/**
	 * Adds every candidate whose split is not yet part of the snapshot and records the split
	 * as snapshotted, so a split listed twice cannot produce duplicate state entries.
	 */
	private void addMissingProgress(
			Iterable<Tuple2<InputSplit, CURSOR>> candidates,
			Set<InputSplit> snapshottedSplits) throws Exception {
		for (Tuple2<InputSplit, CURSOR> candidate : candidates) {
			if (snapshottedSplits.add(candidate.f0)) {
				unionInitialProgress.add(new InnerProgress<>(candidate.f0, candidate.f1));
			}
		}
	}

	/** Builds the descriptor of the union list state that stores {@link InnerProgress} entries. */
	private ListStateDescriptor<InnerProgress<CURSOR>> createProgressStateDescriptor()
			throws NoSuchFieldException {
		TypeInformation<InputSplit> splitType = TypeExtractor.createTypeInfo(InputSplit.class);
		TypeInformation<?> cursorType = TypeExtractor.createTypeInfo(resolveCursorType());

		List<PojoField> pojoFields = new ArrayList<>();
		pojoFields.add(new PojoField(InnerProgress.class.getField("inputSplit"), splitType));
		pojoFields.add(new PojoField(InnerProgress.class.getField("cursor"), cursorType));

		// A class literal is always raw; the concrete cursor type is carried by pojoFields.
		@SuppressWarnings("unchecked")
		Class<InnerProgress<CURSOR>> progressClass =
				(Class<InnerProgress<CURSOR>>) (Class<?>) InnerProgress.class;
		TypeInformation<InnerProgress<CURSOR>> stateTypeInfo =
				new PojoTypeInfo<>(progressClass, pojoFields);
		return new ListStateDescriptor<>(SOURCE_STATE_NAME, stateTypeInfo);
	}

	/**
	 * Finds the type bound to {@code CURSOR} by walking from the runtime class up to
	 * {@link AbstractDynamicParallelSource}, substituting type variables that a lower class in
	 * the hierarchy has bound. Only top-level type variables are substituted; a variable that no
	 * subclass binds is returned as is.
	 *
	 * @return the type argument found at the {@code CURSOR} position
	 * @throws IllegalStateException if no class in the hierarchy extends this class with type
	 *     arguments (for example when it is extended as a raw type)
	 */
	private Type resolveCursorType() {
		Map<TypeVariable<?>, Type> bindings = new HashMap<>();
		for (Class<?> clazz = getClass();
				clazz != null && clazz != AbstractDynamicParallelSource.class;
				clazz = clazz.getSuperclass()) {
			Type genericSuper = clazz.getGenericSuperclass();
			if (!(genericSuper instanceof ParameterizedType)) {
				// Superclass is not generic (or is used raw): nothing to bind at this level.
				continue;
			}
			ParameterizedType parameterizedSuper = (ParameterizedType) genericSuper;
			Type[] declaredArguments = parameterizedSuper.getActualTypeArguments();
			Type[] resolvedArguments = new Type[declaredArguments.length];
			for (int i = 0; i < declaredArguments.length; i++) {
				Type argument = declaredArguments[i];
				resolvedArguments[i] = bindings.containsKey(argument) ? bindings.get(argument) : argument;
			}
			if (AbstractDynamicParallelSource.class.equals(parameterizedSuper.getRawType())) {
				return resolvedArguments[CURSOR_TYPE_PARAMETER_INDEX];
			}
			// Remember what the superclass' own type variables stand for one level further up.
			TypeVariable<?>[] superVariables = clazz.getSuperclass().getTypeParameters();
			for (int i = 0; i < superVariables.length; i++) {
				bindings.put(superVariables[i], resolvedArguments[i]);
			}
		}
		throw new IllegalStateException(
				"Cannot determine the CURSOR type of " + getClass().getName()
						+ ": AbstractDynamicParallelSource is not extended with type arguments");
	}

	/**
	 * A split together with its cursor, as stored in the checkpointed state. The public fields
	 * are looked up by name when the state type information is built, and the no-argument
	 * constructor is kept alongside them.
	 *
	 * @param <CURSOR> serializable read position within the split
	 */
	public static class InnerProgress<CURSOR extends Serializable> implements Serializable {

		private static final long serialVersionUID = -7756210303146639268L;
		public InputSplit inputSplit;
		public CURSOR cursor;

		public InnerProgress() {
		}

		public InnerProgress(InputSplit inputSplit, CURSOR cursor) {
			this.inputSplit = inputSplit;
			this.cursor = cursor;
		}

		public InputSplit getInputSplit() {
			return inputSplit;
		}

		/**
		 * Sets the split.
		 *
		 * @param inputSplit the split this progress belongs to
		 * @return this instance, for chaining
		 */
		public InnerProgress<CURSOR> setInputSplit(InputSplit inputSplit) {
			this.inputSplit = inputSplit;
			return this;
		}

		public CURSOR getCursor() {
			return cursor;
		}

		/**
		 * Sets the cursor.
		 *
		 * @param cursor the read position within the split
		 * @return this instance, for chaining
		 */
		public InnerProgress<CURSOR> setCursor(CURSOR cursor) {
			this.cursor = cursor;
			return this;
		}

		@Override
		public String toString() {
			return "InnerProgress{" + "inputSplit=" + inputSplit + ", cursor=" + cursor + '}';
		}
	}
}