/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.functions.python;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ConfigurationUtils;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.python.PythonConfig;
import org.apache.flink.python.PythonFunctionRunner;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.python.env.ProcessPythonEnvironmentManager;
import org.apache.flink.python.env.PythonDependencyInfo;
import org.apache.flink.python.env.PythonEnvironmentManager;
import org.apache.flink.python.metric.FlinkMetricContainer;
import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.utils.LegacyTypeInfoDataTypeConverter;
import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

/**
 * Base Python stateless {@link RichFlatMapFunction} used to invoke Python stateless functions for the
 * old planner.
 */
@Internal
public abstract class AbstractPythonStatelessFunctionFlatMap
	extends RichFlatMapFunction<Row, Row> implements ResultTypeQueryable<Row> {

	private static final long serialVersionUID = 1L;

	private static final Logger LOG = LoggerFactory.getLogger(AbstractPythonStatelessFunctionFlatMap.class);

	/**
	 * The python config.
	 */
	private final PythonConfig config;

	/**
	 * The offsets of user-defined function inputs.
	 */
	private final int[] userDefinedFunctionInputOffsets;

	/**
	 * The input logical type.
	 */
	protected final RowType inputType;

	/**
	 * The output logical type.
	 */
	protected final RowType outputType;

	/**
	 * The options used to configure the Python worker process.
	 */
	protected final Map<String, String> jobOptions;

	/**
	 * The user-defined function input logical type.
	 */
	protected transient RowType userDefinedFunctionInputType;

	/**
	 * The user-defined function output logical type.
	 */
	protected transient RowType userDefinedFunctionOutputType;

	/**
	 * The queue holding the input elements for which the execution results have not been received.
	 */
	protected transient LinkedBlockingQueue<Row> forwardedInputQueue;

	/**
	 * The queue holding the user-defined function execution results. The execution results
	 * are in the same order as the input elements.
	 */
	protected transient LinkedBlockingQueue<byte[]> userDefinedFunctionResultQueue;

	/**
	 * Use an AtomicBoolean because we start/stop bundles by a timer thread.
	 */
	private transient AtomicBoolean bundleStarted;

	/**
	 * Max number of elements to include in a bundle.
	 */
	private transient int maxBundleSize;

	/**
	 * The collector used to collect records.
	 */
	protected transient Collector<Row> resultCollector;

	/**
	 * Number of processed elements in the current bundle.
	 */
	private transient int elementCount;

	/**
	 * The {@link PythonFunctionRunner} which is responsible for Python user-defined function execution.
	 */
	private transient PythonFunctionRunner<Row> pythonFunctionRunner;

	/**
	 * Reusable InputStream used to holding the execution results to be deserialized.
	 */
	protected transient ByteArrayInputStreamWithPos bais;

	/**
	 * InputStream Wrapper.
	 */
	protected transient DataInputViewStreamWrapper baisWrapper;

	/**
	 * The TypeSerializer for user-defined function execution results.
	 */
	protected transient TypeSerializer<Row> userDefinedFunctionTypeSerializer;

	/**
	 * The type serializer for the forwarded fields.
	 */
	protected transient TypeSerializer<Row> forwardedInputSerializer;

	public AbstractPythonStatelessFunctionFlatMap(
		Configuration config,
		RowType inputType,
		RowType outputType,
		int[] userDefinedFunctionInputOffsets) {
		this.inputType = Preconditions.checkNotNull(inputType);
		this.outputType = Preconditions.checkNotNull(outputType);
		this.userDefinedFunctionInputOffsets = Preconditions.checkNotNull(userDefinedFunctionInputOffsets);
		this.config = new PythonConfig(Preconditions.checkNotNull(config));
		this.jobOptions = buildJobOptions(config);
	}

	protected PythonConfig getPythonConfig() {
		return config;
	}

	@Override
	@SuppressWarnings("unchecked")
	public void open(Configuration parameters) throws Exception {
		super.open(parameters);

		this.elementCount = 0;
		this.bundleStarted = new AtomicBoolean(false);
		this.maxBundleSize = config.getMaxBundleSize();
		if (this.maxBundleSize <= 0) {
			this.maxBundleSize = PythonOptions.MAX_BUNDLE_SIZE.defaultValue();
			LOG.error("Invalid value for the maximum bundle size. Using default value of " +
				this.maxBundleSize + '.');
		} else {
			LOG.info("The maximum bundle size is configured to {}.", this.maxBundleSize);
		}

		if (config.getMaxBundleTimeMills() != PythonOptions.MAX_BUNDLE_TIME_MILLS.defaultValue()) {
			LOG.info("Maximum bundle time takes no effect in old planner under batch mode. " +
				"Config maximum bundle size instead! " +
				"Under batch mode, bundle size should be enough to control both throughput and latency.");
		}

		forwardedInputQueue = new LinkedBlockingQueue<>();
		userDefinedFunctionResultQueue = new LinkedBlockingQueue<>();
		userDefinedFunctionInputType = new RowType(
			Arrays.stream(userDefinedFunctionInputOffsets)
				.mapToObj(i -> inputType.getFields().get(i))
				.collect(Collectors.toList()));
		bais = new ByteArrayInputStreamWithPos();
		baisWrapper = new DataInputViewStreamWrapper(bais);

		userDefinedFunctionOutputType = new RowType(outputType.getFields().subList(getForwardedFieldsCount(), outputType.getFieldCount()));
		userDefinedFunctionTypeSerializer = PythonTypeUtils.toFlinkTypeSerializer(userDefinedFunctionOutputType);

		this.pythonFunctionRunner = createPythonFunctionRunner();
		this.pythonFunctionRunner.open();
	}

	@Override
	public void flatMap(Row value, Collector<Row> out) throws Exception {
		this.resultCollector = out;
		bufferInput(value);

		checkInvokeStartBundle();
		pythonFunctionRunner.processElement(getFunctionInput(value));
		checkInvokeFinishBundleByCount();
		emitResults();
	}

	@Override
	public TypeInformation<Row> getProducedType() {
		return (TypeInformation<Row>) LegacyTypeInfoDataTypeConverter
			.toLegacyTypeInfo(LogicalTypeDataTypeConverter.toDataType(outputType));
	}

	@Override
	public void close() throws Exception {
		try {
			invokeFinishBundle();

			if (pythonFunctionRunner != null) {
				pythonFunctionRunner.close();
				pythonFunctionRunner = null;
			}
		} finally {
			super.close();
		}
	}

	/**
	 * Returns the {@link PythonEnv} used to create PythonEnvironmentManager..
	 */
	public abstract PythonEnv getPythonEnv();

	public abstract PythonFunctionRunner<Row> createPythonFunctionRunner() throws IOException;

	public abstract void bufferInput(Row input);

	public abstract void emitResults() throws IOException;

	public abstract int getForwardedFieldsCount();

	protected PythonEnvironmentManager createPythonEnvironmentManager() throws IOException {
		PythonDependencyInfo dependencyInfo = PythonDependencyInfo.create(
			config, getRuntimeContext().getDistributedCache());
		PythonEnv pythonEnv = getPythonEnv();
		if (pythonEnv.getExecType() == PythonEnv.ExecType.PROCESS) {
			return new ProcessPythonEnvironmentManager(
				dependencyInfo,
				ConfigurationUtils.splitPaths(System.getProperty("java.io.tmpdir")),
				System.getenv());
		} else {
			throw new UnsupportedOperationException(String.format(
				"Execution type '%s' is not supported.", pythonEnv.getExecType()));
		}
	}

	protected FlinkMetricContainer getFlinkMetricContainer() {
		return this.config.isMetricEnabled() ?
			new FlinkMetricContainer(getRuntimeContext().getMetricGroup()) : null;
	}

	/**
	 * Checks whether to invoke startBundle.
	 */
	private void checkInvokeStartBundle() throws Exception {
		if (bundleStarted.compareAndSet(false, true)) {
			pythonFunctionRunner.startBundle();
		}
	}

	/**
	 * Checks whether to invoke finishBundle by elements count. Called in flatMap.
	 */
	private void checkInvokeFinishBundleByCount() throws Exception {
		elementCount++;
		if (elementCount >= maxBundleSize) {
			invokeFinishBundle();
		}
	}

	private void invokeFinishBundle() throws Exception {
		if (bundleStarted.compareAndSet(true, false)) {
			pythonFunctionRunner.finishBundle();
			emitResults();
			elementCount = 0;
		}
	}

	private Row getFunctionInput(Row element) {
		return Row.project(element, userDefinedFunctionInputOffsets);
	}

	private Map<String, String> buildJobOptions(Configuration config) {
		Map<String, String> jobOptions = new HashMap<>();
		if (config.containsKey("table.exec.timezone")) {
			jobOptions.put("table.exec.timezone", config.getString("table.exec.timezone", null));
		}
		return jobOptions;
	}
}