/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.api.java.utils;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Keys;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.functions.SampleInCoordinator;
import org.apache.flink.api.java.functions.SampleInPartition;
import org.apache.flink.api.java.functions.SampleWithFraction;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.summarize.aggregation.SummaryAggregatorFactory;
import org.apache.flink.api.java.summarize.aggregation.TupleSummaryAggregator;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * This class provides simple utility methods for {@link DataSet}s: zipping elements with a
 * consecutive index or with a unique identifier, sampling, range partitioning with a given
 * {@link DataDistribution}, summarizing data sets of tuples, and computing checksums.
 */
@PublicEvolving
public final class DataSetUtils {

	/**
	 * Method that goes over all the elements in each partition in order to retrieve
	 * the total number of elements.
	 *
	 * <p>Exactly one record is collected per invocation of the partition function, after the
	 * loop, so an empty partition reports a count of 0.
	 *
	 * @param input the DataSet received as input
	 * @return a data set containing tuples of subtask index, number of elements mappings.
	 */
	public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
			@Override
			public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
				long counter = 0;
				for (T ignored : values) {
					counter++;
				}
				out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
			}
		});
	}

	/**
	 * Method that assigns a unique {@link Long} value to all elements in the input data set. The generated values are
	 * consecutive.
	 *
	 * <p>The per-partition element counts are computed with {@link #countElementsPerPartition(DataSet)} and
	 * broadcast to every parallel instance. Each instance starts numbering at the sum of the counts of all
	 * partitions with a lower subtask index.
	 *
	 * @param input the input data set
	 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
	 */
	public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {

		DataSet<Tuple2<Integer, Long>> elementCount = countElementsPerPartition(input);

		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

			// next index to emit; initialised to this partition's offset in open()
			long start = 0;

			@Override
			public void open(Configuration parameters) throws Exception {
				super.open(parameters);

				List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariableWithInitializer(
						"counts",
						new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
							@Override
							public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
								// sort the list by task id to calculate the correct offset
								List<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
								for (Tuple2<Integer, Long> datum : data) {
									sortedData.add(datum);
								}
								Collections.sort(sortedData, new Comparator<Tuple2<Integer, Long>>() {
									@Override
									public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
										return o1.f0.compareTo(o2.f0);
									}
								});
								return sortedData;
							}
						});

				// Compute the offset for this partition: the sum of the counts of all lower subtask indices.
				// Position i in the sorted list is assumed to hold subtask i's count (one count per subtask).
				// The offset is built in a local and then assigned, so that if open() is ever invoked more
				// than once on the same instance, counts are not added on top of a previous value.
				long offset = 0;
				for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
					offset += offsets.get(i).f1;
				}
				start = offset;
			}

			@Override
			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
				for (T value: values) {
					out.collect(new Tuple2<>(start++, value));
				}
			}
		}).withBroadcastSet(elementCount, "counts");
	}

	/**
	 * Method that assigns a unique {@link Long} value to all elements in the input data set as described below.
	 * <ul>
	 *  <li> a map partition function is applied to the input data set
	 *  <li> each parallel task holds a counter c which is increased for each record
	 *  <li> c is shifted left by n bits, where n is the number of bits needed to represent
	 *       (number of parallel tasks - 1)
	 *  <li> to create a unique ID among all tasks, the task id is added to the shifted counter
	 *  <li> for each record, the resulting value is emitted together with the record
	 * </ul>
	 *
	 * <p>The IDs are unique but, unlike {@link #zipWithIndex(DataSet)}, not consecutive.
	 *
	 * @param input the input data set
	 * @return a data set of tuple 2 consisting of ids and initial values.
	 */
	public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId(DataSet<T> input) {

		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

			// bit size of Long.MAX_VALUE (63); labels must stay below this many bits
			long maxBitSize = getBitSize(Long.MAX_VALUE);
			long shifter = 0;
			long start = 0;
			long taskId = 0;

			@Override
			public void open(Configuration parameters) throws Exception {
				super.open(parameters);
				shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
				taskId = getRuntimeContext().getIndexOfThisSubtask();
				// restart numbering on every open() so a re-opened instance behaves like a fresh one
				start = 0;
			}

			@Override
			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
				for (T value : values) {
					// check the range before building the label so an overflowing shift is never computed
					if (getBitSize(start) + shifter >= maxBitSize) {
						throw new Exception("Exceeded Long value range while generating labels");
					}
					long label = (start << shifter) + taskId;
					out.collect(new Tuple2<>(label, value));
					start++;
				}
			}
		});
	}

	// --------------------------------------------------------------------------------------------
	//  Sample
	// --------------------------------------------------------------------------------------------

	/**
	 * Generate a sample of DataSet by the probability fraction of each element.
	 * A random seed is drawn from {@code Utils.RNG}.
	 *
	 * @param input           The DataSet to sample from.
	 * @param withReplacement Whether element can be selected more than once.
	 * @param fraction        Probability that each element is chosen, should be [0,1] without replacement,
	 *                        and [0, ∞) with replacement. While fraction is larger than 1, the elements are
	 *                        expected to be selected multiple times into sample on average.
	 * @return The sampled DataSet
	 */
	public static <T> MapPartitionOperator<T, T> sample(
		DataSet<T> input,
		final boolean withReplacement,
		final double fraction) {

		return sample(input, withReplacement, fraction, Utils.RNG.nextLong());
	}

	/**
	 * Generate a sample of DataSet by the probability fraction of each element.
	 *
	 * @param input           The DataSet to sample from.
	 * @param withReplacement Whether element can be selected more than once.
	 * @param fraction        Probability that each element is chosen, should be [0,1] without replacement,
	 *                        and [0, ∞) with replacement. While fraction is larger than 1, the elements are
	 *                        expected to be selected multiple times into sample on average.
	 * @param seed            random number generator seed.
	 * @return The sampled DataSet
	 */
	public static <T> MapPartitionOperator<T, T> sample(
		DataSet<T> input,
		final boolean withReplacement,
		final double fraction,
		final long seed) {

		return input.mapPartition(new SampleWithFraction<T>(withReplacement, fraction, seed));
	}

	/**
	 * Generate a sample of DataSet which contains fixed size elements.
	 * A random seed is drawn from {@code Utils.RNG}.
	 *
	 * <p><strong>NOTE:</strong> Sample with fixed size is not as efficient as sample with fraction, use sample with
	 * fraction unless you need exact precision.
	 *
	 * @param input           The DataSet to sample from.
	 * @param withReplacement Whether element can be selected more than once.
	 * @param numSamples      The expected sample size.
	 * @return The sampled DataSet
	 */
	public static <T> DataSet<T> sampleWithSize(
		DataSet<T> input,
		final boolean withReplacement,
		final int numSamples) {

		return sampleWithSize(input, withReplacement, numSamples, Utils.RNG.nextLong());
	}

	/**
	 * Generate a sample of DataSet which contains fixed size elements.
	 *
	 * <p>Sampling runs in two phases: every partition is sampled locally with {@link SampleInPartition},
	 * then the partial samples are merged by a {@link SampleInCoordinator} in a non-grouped
	 * {@link GroupReduceOperator}.
	 *
	 * <p><strong>NOTE:</strong> Sample with fixed size is not as efficient as sample with fraction, use sample with
	 * fraction unless you need exact precision.
	 *
	 * @param input           The DataSet to sample from.
	 * @param withReplacement Whether element can be selected more than once.
	 * @param numSamples      The expected sample size.
	 * @param seed            Random number generator seed.
	 * @return The sampled DataSet
	 */
	public static <T> DataSet<T> sampleWithSize(
		DataSet<T> input,
		final boolean withReplacement,
		final int numSamples,
		final long seed) {

		SampleInPartition<T> sampleInPartition = new SampleInPartition<>(withReplacement, numSamples, seed);
		SampleInCoordinator<T> sampleInCoordinator = new SampleInCoordinator<>(withReplacement, numSamples, seed);
		String callLocation = Utils.getCallLocationName();

		// The partition-sampling operator is passed directly so its element type is inferred by the
		// diamond instead of being erased to a raw MapPartitionOperator.
		// There is no previous group, so the parallelism of GroupReduceOperator is always 1.
		return new GroupReduceOperator<>(input.mapPartition(sampleInPartition), input.getType(), sampleInCoordinator, callLocation);
	}

	// --------------------------------------------------------------------------------------------
	//  Partition
	// --------------------------------------------------------------------------------------------

	/**
	 * Range-partitions a DataSet on the specified tuple field positions.
	 *
	 * @param input        The DataSet to partition.
	 * @param distribution The data distribution used to determine the range boundaries.
	 * @param fields       The tuple field positions to partition on.
	 * @return The range-partitioned DataSet.
	 */
	public static <T> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, int... fields) {
		return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.ExpressionKeys<>(fields, input.getType(), false), distribution, Utils.getCallLocationName());
	}

	/**
	 * Range-partitions a DataSet on the specified fields.
	 *
	 * @param input        The DataSet to partition.
	 * @param distribution The data distribution used to determine the range boundaries.
	 * @param fields       The field expressions to partition on.
	 * @return The range-partitioned DataSet.
	 */
	public static <T> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, String... fields) {
		return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.ExpressionKeys<>(fields, input.getType()), distribution, Utils.getCallLocationName());
	}

	/**
	 * Range-partitions a DataSet using the specified key selector function.
	 *
	 * @param input        The DataSet to partition.
	 * @param distribution The data distribution used to determine the range boundaries.
	 * @param keyExtractor The key selector that extracts the partitioning key; it is cleaned via
	 *                     {@code input.clean(...)} before use.
	 * @return The range-partitioned DataSet.
	 */
	public static <T, K extends Comparable<K>> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, KeySelector<T, K> keyExtractor) {
		final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, input.getType());
		return new PartitionOperator<>(input, PartitionOperatorBase.PartitionMethod.RANGE, new Keys.SelectorFunctionKeys<>(input.clean(keyExtractor), input.getType(), keyType), distribution, Utils.getCallLocationName());
	}

	// --------------------------------------------------------------------------------------------
	//  Summarize
	// --------------------------------------------------------------------------------------------

	/**
	 * Summarize a DataSet of Tuples by collecting single pass statistics for all columns.
	 *
	 * <p>Each partition is aggregated into a {@link TupleSummaryAggregator}, the partial aggregators
	 * are combined with a reduce, and the single result is fetched with {@link DataSet#collect()},
	 * which triggers execution of the program.
	 *
	 * <p>Example usage:
	 * <pre>
	 * {@code
	 * DataSet<Tuple3<Double, String, Boolean>> input = // [...]
	 * Tuple3<NumericColumnSummary, StringColumnSummary, BooleanColumnSummary> summary = DataSetUtils.summarize(input);
	 *
	 * summary.f0.getStandardDeviation();
	 * summary.f1.getMaxLength();
	 * }
	 * </pre>
	 *
	 * @param input the DataSet of tuples to summarize
	 * @return the summary as a Tuple the same width as input rows
	 * @throws IllegalArgumentException if the input type is not a tuple type
	 * @throws Exception if collecting the result fails
	 */
	public static <R extends Tuple, T extends Tuple> R summarize(DataSet<T> input) throws Exception {
		if (!input.getType().isTupleType()) {
			throw new IllegalArgumentException("summarize() is only implemented for DataSet's of Tuples");
		}
		final TupleTypeInfoBase<?> inType = (TupleTypeInfoBase<?>) input.getType();
		DataSet<TupleSummaryAggregator<R>> result = input.mapPartition(new MapPartitionFunction<T, TupleSummaryAggregator<R>>() {
			@Override
			public void mapPartition(Iterable<T> values, Collector<TupleSummaryAggregator<R>> out) throws Exception {
				TupleSummaryAggregator<R> aggregator = SummaryAggregatorFactory.create(inType);
				for (Tuple value : values) {
					aggregator.aggregate(value);
				}
				out.collect(aggregator);
			}
		}).reduce(new ReduceFunction<TupleSummaryAggregator<R>>() {
			@Override
			public TupleSummaryAggregator<R> reduce(TupleSummaryAggregator<R> agg1, TupleSummaryAggregator<R> agg2) throws Exception {
				agg1.combine(agg2);
				return agg1;
			}
		});
		return result.collect().get(0).result();
	}

	// --------------------------------------------------------------------------------------------
	//  Checksum
	// --------------------------------------------------------------------------------------------

	/**
	 * Convenience method to get the count (number of elements) of a DataSet
	 * as well as the checksum (sum over element hashes).
	 *
	 * <p>This adds an output to {@code input} and calls {@code execute()} on its execution
	 * environment; the result is read back from an accumulator under a freshly generated id.
	 *
	 * @param input the DataSet to checksum
	 * @return A ChecksumHashCode that represents the count and checksum of elements in the data set.
	 * @throws Exception if executing the job fails
	 * @deprecated replaced with {@code org.apache.flink.graph.asm.dataset.ChecksumHashCode} in Gelly
	 */
	@Deprecated
	public static <T> Utils.ChecksumHashCode checksumHashCode(DataSet<T> input) throws Exception {
		final String id = new AbstractID().toString();

		input.output(new Utils.ChecksumHashCodeHelper<T>(id)).name("ChecksumHashCode");

		JobExecutionResult res = input.getExecutionEnvironment().execute();
		return res.<Utils.ChecksumHashCode> getAccumulatorResult(id);
	}

	// *************************************************************************
	//     UTIL METHODS
	// *************************************************************************

	/**
	 * Returns the number of bits needed to represent the given value, i.e. the 1-based position
	 * of its highest set bit (0 for 0, 1 for 1, 63 for {@link Long#MAX_VALUE}).
	 *
	 * <p>Intended for non-negative inputs; negative values do not yield a meaningful bit count.
	 *
	 * @param value a non-negative value
	 * @return the bit size of {@code value}
	 */
	public static int getBitSize(long value) {
		if (value > Integer.MAX_VALUE) {
			return 64 - Integer.numberOfLeadingZeros((int) (value >> 32));
		} else {
			return 32 - Integer.numberOfLeadingZeros((int) value);
		}
	}

	/**
	 * Private constructor to prevent instantiation.
	 */
	private DataSetUtils() {
		throw new RuntimeException();
	}
}