/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package org.apache.flink.graph.streaming;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.windowing.time.Time;

import java.io.Serializable;
import java.util.concurrent.TimeUnit;


/**
 * Graph aggregation on parallel time windows using a tree of combiners.
 *
 * <p>Edges are first folded into per-partition partial aggregates over tumbling
 * time windows. The partial aggregates are then merged pairwise in a reduction
 * tree that halves the parallelism at each level, and a final global reduce
 * produces one summary per window, which is optionally transformed before it
 * is emitted.
 * 
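 * <p>A minimal usage sketch, assuming {@code fold} and {@code combine} are
 * user-supplied {@link EdgesFold} and {@link ReduceFunction} implementations
 * over a {@code Long} partial aggregate:
 * <pre>{@code
 * DataStream<Edge<Long, Long>> edges = ...;
 *
 * // 1-second tumbling windows, transient state disabled, default tree degree
 * SummaryTreeReduce<Long, Long, Long, Long> summary =
 *     new SummaryTreeReduce<>(fold, combine, 0L, 1000, false);
 *
 * DataStream<Long> result = summary.run(edges);
 * }</pre>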
 */
public class SummaryTreeReduce<K, EV, S extends Serializable, T> extends SummaryBulkAggregation<K, EV, S, T> {

	private static final long serialVersionUID = 1L;
	private int degree;
	

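	/**
	 * Creates a tree-reduce summary aggregation.
	 *
	 * @param updateFun fold function that adds an edge to a partial aggregate
	 * @param combineFun reduce function that merges two partial aggregates
	 * @param transformFun optional mapping applied to the final aggregate before emission, or null
	 * @param initialVal initial value of the partial aggregate
	 * @param timeMillis size of the tumbling time window in milliseconds
	 * @param transientState transient-state flag forwarded to {@link SummaryBulkAggregation}
	 * @param degree parallelism of the per-window partial aggregation and width of the
	 *               reduction tree; -1 uses the parallelism of the input edge stream
	 */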
	public SummaryTreeReduce(EdgesFold<K, EV, S> updateFun, ReduceFunction<S> combineFun, MapFunction<S, T> transformFun, S initialVal, long timeMillis, boolean transientState, int degree) {
		super(updateFun, combineFun, transformFun, initialVal, timeMillis, transientState);
		this.degree = degree;
	}

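	/**
	 * Creates a tree-reduce summary aggregation without a transform function;
	 * the final aggregate is emitted as-is.
	 */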
	public SummaryTreeReduce(EdgesFold<K, EV, S> updateFun, ReduceFunction<S> combineFun, S initialVal, long timeMillis, boolean transientState, int degree) {
		this(updateFun, combineFun, null, initialVal, timeMillis, transientState, degree);
	}

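	/**
	 * Creates a tree-reduce summary aggregation without a transform function,
	 * using the parallelism of the input edge stream as the tree degree.
	 */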
	public SummaryTreeReduce(EdgesFold<K, EV, S> updateFun, ReduceFunction<S> combineFun, S initialVal, long timeMillis, boolean transientState) {
		this(updateFun, combineFun, null, initialVal, timeMillis, transientState, -1);
	}

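	/**
	 * Runs the aggregation on the given edge stream: edges are folded into
	 * per-partition partial aggregates over tumbling time windows, the partials
	 * are merged pairwise in a reduction tree (see {@link #enhance}), and a
	 * final global reduce produces one result per window.
	 */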
	@SuppressWarnings("unchecked")
	@Override
	public DataStream<T> run(final DataStream<Edge<K, EV>> edgeStream) {
		TypeInformation<Tuple2<Integer, Edge<K, EV>>> basicTypeInfo = new TupleTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO, edgeStream.getType());

		TupleTypeInfo edgeTypeInfo = (TupleTypeInfo) edgeStream.getType();
		TypeInformation<S> partialAggType = TypeExtractor.createTypeInfo(EdgesFold.class, getUpdateFun().getClass(), 2, edgeTypeInfo.getTypeAt(0), edgeTypeInfo.getTypeAt(2));
		TypeInformation<Tuple2<Integer, S>> partialTypeInfo = new TupleTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO, partialAggType);

		degree = (degree == -1) ? edgeStream.getParallelism() : degree;
		
		DataStream<S> partialAgg = edgeStream
				.map(new PartitionMapper<>()).returns(basicTypeInfo)
				.setParallelism(degree)
				.keyBy(0)
				.timeWindow(Time.of(timeMillis, TimeUnit.MILLISECONDS))
				.fold(getInitialValue(), new PartialAgg<>(getUpdateFun(), partialAggType)).setParallelism(degree);
		//re-attach partition ids to the partial aggregates and merge them in a reduction tree

		DataStream<Tuple2<Integer, S>> treeAgg = enhance(partialAgg.map(new PartitionMapper<>()).setParallelism(degree).returns(partialTypeInfo), partialTypeInfo);

		DataStream<S> resultStream = treeAgg.map(new PartitionStripper<>()).setParallelism(treeAgg.getParallelism())
				.timeWindowAll(Time.of(timeMillis, TimeUnit.MILLISECONDS))
				.reduce(getCombineFun())
				.flatMap(getAggregator(edgeStream)).setParallelism(1);

		return (getTransform() != null) ? resultStream.map(getTransform()) : (DataStream<T>) resultStream;
	}

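	/**
	 * Recursively merges partial aggregates in a binary reduction tree: adjacent
	 * partitions are paired, re-windowed and combined, halving the parallelism
	 * at every level until at most two partitions remain.
	 */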
	private DataStream<Tuple2<Integer, S>> enhance(DataStream<Tuple2<Integer, S>> input, TypeInformation<Tuple2<Integer, S>> aggType) {

		if (input.getParallelism() <= 2) {
			return input;
		}

		int nextParal = input.getParallelism() / 2;
		DataStream<Tuple2<Integer, S>> unpartitionedStream =
				input.keyBy(new KeySelector<Tuple2<Integer, S>, Integer>() {
					//collapse two partitions into one
					@Override
					public Integer getKey(Tuple2<Integer, S> record) throws Exception {
						return record.f0 / 2;
					}
				});

		//repartition stream to p / 2 aggregators
		KeyedStream<Tuple2<Integer, S>, Integer> repartitionedStream =
				unpartitionedStream.map(new PartitionReMapper<S>()).returns(aggType)
						.setParallelism(nextParal)
						.keyBy(0);

		//window again on event time and aggregate
		DataStream<Tuple2<Integer, S>> aggregatedStream =
				repartitionedStream.timeWindow(Time.of(timeMillis, TimeUnit.MILLISECONDS))
						.reduce(new AggregationWrapper<>(getCombineFun()))
						.setParallelism(nextParal);
		return enhance(aggregatedStream, aggType);
	}

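	/**
	 * Re-tags each partial aggregate with the index of the subtask that
	 * processes it, so the next tree level can be keyed on the new partition id.
	 */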
	protected static final class PartitionReMapper<Y> extends RichMapFunction<Tuple2<Integer, Y>, Tuple2<Integer, Y>> {

		private int partitionIndex;

		@Override
		public void open(Configuration parameters) throws Exception {
			this.partitionIndex = getRuntimeContext().getIndexOfThisSubtask();
		}

		@Override
		public Tuple2<Integer, Y> map(Tuple2<Integer, Y> tpl) throws Exception {
			return new Tuple2<>(partitionIndex, tpl.f1);
		}
	}

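	/**
	 * Drops the partition id and forwards the bare partial aggregate.
	 */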
	public static class PartitionStripper<S> implements MapFunction<Tuple2<Integer, S>, S> {
		@Override
		public S map(Tuple2<Integer, S> tpl) throws Exception {
			return tpl.f1;
		}
	}

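	/**
	 * Applies the user-supplied combine function to the partial aggregates while
	 * keeping the partition id of the first input.
	 */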
	public static class AggregationWrapper<S> implements ReduceFunction<Tuple2<Integer, S>> {

		private final ReduceFunction<S> wrappedFunction;

		protected AggregationWrapper(ReduceFunction<S> wrappedFunction) {
			this.wrappedFunction = wrappedFunction;
		}

		@Override
		public Tuple2<Integer, S> reduce(Tuple2<Integer, S> tpl1, Tuple2<Integer, S> tpl2) throws Exception {
			return new Tuple2<>(tpl1.f0, wrappedFunction.reduce(tpl1.f1, tpl2.f1));
		}
	}
}