/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.optimizer.postpass;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.GenericDataSourceBase;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.base.BulkIterationBase;
import org.apache.flink.api.common.operators.base.DeltaIterationBase;
import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.runtime.RuntimeComparatorFactory;
import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory;
import org.apache.flink.api.java.typeutils.runtime.RuntimeSerializerFactory;
import org.apache.flink.optimizer.CompilerException;
import org.apache.flink.optimizer.CompilerPostPassException;
import org.apache.flink.optimizer.plan.BulkIterationPlanNode;
import org.apache.flink.optimizer.plan.BulkPartialSolutionPlanNode;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.DualInputPlanNode;
import org.apache.flink.optimizer.plan.NAryUnionPlanNode;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.PlanNode;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plan.SinkPlanNode;
import org.apache.flink.optimizer.plan.SolutionSetPlanNode;
import org.apache.flink.optimizer.plan.SourcePlanNode;
import org.apache.flink.optimizer.plan.WorksetIterationPlanNode;
import org.apache.flink.optimizer.plan.WorksetPlanNode;
import org.apache.flink.optimizer.util.NoOpUnaryUdfOp;
import org.apache.flink.runtime.operators.DriverStrategy;

/**
 * The post-optimizer plan traversal. This traversal fills in the API-specific utilities
 * (serializers and comparators).
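 *
 * <p>A minimal sketch of how this post pass is typically applied, assuming an
 * {@link OptimizedPlan} has already been produced by the optimizer:
 *
 * <pre>{@code
 * OptimizedPlan optimizedPlan = ...; // result of plan enumeration and selection
 * OptimizerPostPass postPass = new JavaApiPostPass();
 * postPass.postPass(optimizedPlan); // attaches serializers and comparators to the plan
 * }</pre>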
 */
public class JavaApiPostPass implements OptimizerPostPass {
	
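	/** Plan nodes that have already been processed, to avoid revisiting shared nodes. */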
	private final Set<PlanNode> alreadyDone = new HashSet<PlanNode>();

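	/** Execution config of the original program, needed to create serializers and comparators. */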
	private ExecutionConfig executionConfig = null;
	
	@Override
	public void postPass(OptimizedPlan plan) {

		executionConfig = plan.getOriginalPlan().getExecutionConfig();

		for (SinkPlanNode sink : plan.getDataSinks()) {
			traverse(sink);
		}
	}
	

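	/**
	 * Traverses a plan node: sets the serializers and comparators required by the node's
	 * driver strategy and then descends recursively into its input channels. Nodes that
	 * have already been visited are skipped.
	 *
	 * @param node The plan node to process.
	 */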
	protected void traverse(PlanNode node) {
		if (!alreadyDone.add(node)) {
			// already worked on that one
			return;
		}
		
		// distinguish the node types
		if (node instanceof SinkPlanNode) {
			// descend to the input channel
			SinkPlanNode sn = (SinkPlanNode) node;
			Channel inchannel = sn.getInput();
			traverseChannel(inchannel);
		}
		else if (node instanceof SourcePlanNode) {
			TypeInformation<?> typeInfo = getTypeInfoFromSource((SourcePlanNode) node);
			((SourcePlanNode) node).setSerializer(createSerializer(typeInfo));
		}
		else if (node instanceof BulkIterationPlanNode) {
			BulkIterationPlanNode iterationNode = (BulkIterationPlanNode) node;

			if (iterationNode.getRootOfStepFunction() instanceof NAryUnionPlanNode) {
				throw new CompilerException("Optimizer cannot compile an iteration step function where next partial solution is created by a Union node.");
			}
			
			// traverse the termination criterion (if any); needed in case of an intermediate termination criterion
			if (iterationNode.getRootOfTerminationCriterion() != null) {
				SingleInputPlanNode addMapper = (SingleInputPlanNode) iterationNode.getRootOfTerminationCriterion();
				traverseChannel(addMapper.getInput());
			}

			BulkIterationBase<?> operator = (BulkIterationBase<?>) iterationNode.getProgramOperator();

			// set the serializer
			iterationNode.setSerializerForIterationChannel(createSerializer(operator.getOperatorInfo().getOutputType()));

			// done, we can now propagate our info down
			traverseChannel(iterationNode.getInput());
			traverse(iterationNode.getRootOfStepFunction());
		}
		else if (node instanceof WorksetIterationPlanNode) {
			WorksetIterationPlanNode iterationNode = (WorksetIterationPlanNode) node;
			
			if (iterationNode.getNextWorkSetPlanNode() instanceof NAryUnionPlanNode) {
				throw new CompilerException("Optimizer cannot compile a workset iteration step function where the next workset is produced by a Union node.");
			}
			if (iterationNode.getSolutionSetDeltaPlanNode() instanceof NAryUnionPlanNode) {
				throw new CompilerException("Optimizer cannot compile a workset iteration step function where the solution set delta is produced by a Union node.");
			}
			
			DeltaIterationBase<?, ?> operator = (DeltaIterationBase<?, ?>) iterationNode.getProgramOperator();
			
			// set the serializers and comparators for the workset iteration
			iterationNode.setSolutionSetSerializer(createSerializer(operator.getOperatorInfo().getFirstInputType()));
			iterationNode.setWorksetSerializer(createSerializer(operator.getOperatorInfo().getSecondInputType()));
			iterationNode.setSolutionSetComparator(createComparator(operator.getOperatorInfo().getFirstInputType(),
					iterationNode.getSolutionSetKeyFields(), getSortOrders(iterationNode.getSolutionSetKeyFields(), null)));
			
			// traverse the inputs
			traverseChannel(iterationNode.getInput1());
			traverseChannel(iterationNode.getInput2());
			
			// traverse the step function
			traverse(iterationNode.getSolutionSetDeltaPlanNode());
			traverse(iterationNode.getNextWorkSetPlanNode());
		}
		else if (node instanceof SingleInputPlanNode) {
			SingleInputPlanNode sn = (SingleInputPlanNode) node;
			
			if (!(sn.getOptimizerNode().getOperator() instanceof SingleInputOperator)) {
				
				// Special case for delta iterations
				if (sn.getOptimizerNode().getOperator() instanceof NoOpUnaryUdfOp) {
					traverseChannel(sn.getInput());
					return;
				} else {
					throw new RuntimeException("Wrong operator type found in post pass.");
				}
			}
			
			SingleInputOperator<?, ?, ?> singleInputOperator = (SingleInputOperator<?, ?, ?>) sn.getOptimizerNode().getOperator();
			
			// parameterize the node's driver strategy
			for (int i = 0; i < sn.getDriverStrategy().getNumRequiredComparators(); i++) {
				sn.setComparator(createComparator(singleInputOperator.getOperatorInfo().getInputType(), sn.getKeys(i),
						getSortOrders(sn.getKeys(i), sn.getSortOrders(i))), i);
			}
			// done, we can now propagate our info down
			traverseChannel(sn.getInput());
			
			// don't forget the broadcast inputs
			for (Channel c: sn.getBroadcastInputs()) {
				traverseChannel(c);
			}
		}
		else if (node instanceof DualInputPlanNode) {
			DualInputPlanNode dn = (DualInputPlanNode) node;
			
			if (!(dn.getOptimizerNode().getOperator() instanceof DualInputOperator)) {
				throw new RuntimeException("Wrong operator type found in post pass.");
			}
			
			DualInputOperator<?, ?, ?, ?> dualInputOperator = (DualInputOperator<?, ?, ?, ?>) dn.getOptimizerNode().getOperator();
			
			// parameterize the node's driver strategy
			if (dn.getDriverStrategy().getNumRequiredComparators() > 0) {
				dn.setComparator1(createComparator(dualInputOperator.getOperatorInfo().getFirstInputType(), dn.getKeysForInput1(),
					getSortOrders(dn.getKeysForInput1(), dn.getSortOrders())));
				dn.setComparator2(createComparator(dualInputOperator.getOperatorInfo().getSecondInputType(), dn.getKeysForInput2(),
						getSortOrders(dn.getKeysForInput2(), dn.getSortOrders())));

				dn.setPairComparator(createPairComparator(dualInputOperator.getOperatorInfo().getFirstInputType(),
						dualInputOperator.getOperatorInfo().getSecondInputType()));
				
			}
						
			traverseChannel(dn.getInput1());
			traverseChannel(dn.getInput2());
			
			// don't forget the broadcast inputs
			for (Channel c: dn.getBroadcastInputs()) {
				traverseChannel(c);
			}
			
		}
		// catch the sources of the iterative step functions
		else if (node instanceof BulkPartialSolutionPlanNode ||
				node instanceof SolutionSetPlanNode ||
				node instanceof WorksetPlanNode)
		{
			// nothing to do here; the enclosing iteration node sets up the serializers for these nodes
		}
		else if (node instanceof NAryUnionPlanNode) {
			// traverse all input channels
			for (Channel channel : node.getInputs()) {
				traverseChannel(channel);
			}
		}
		else {
			throw new CompilerPostPassException("Unknown node type encountered: " + node.getClass().getName());
		}
	}
	
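	/**
	 * Parameterizes a channel with its serializer and, where the ship or local strategy
	 * requires it, with the corresponding comparators, then continues the traversal at the
	 * channel's source node.
	 *
	 * @param channel The channel to parameterize.
	 */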
	private void traverseChannel(Channel channel) {
		
		PlanNode source = channel.getSource();
		Operator<?> javaOp = source.getProgramOperator();
		
		TypeInformation<?> type = javaOp.getOperatorInfo().getOutputType();

		if (javaOp instanceof GroupReduceOperatorBase &&
				(source.getDriverStrategy() == DriverStrategy.SORTED_GROUP_COMBINE || source.getDriverStrategy() == DriverStrategy.ALL_GROUP_REDUCE_COMBINE)) {
			GroupReduceOperatorBase<?, ?, ?> groupNode = (GroupReduceOperatorBase<?, ?, ?>) javaOp;
			type = groupNode.getInput().getOperatorInfo().getOutputType();
		}
		else if (javaOp instanceof PlanUnwrappingReduceGroupOperator &&
				source.getDriverStrategy() == DriverStrategy.SORTED_GROUP_COMBINE) {
			PlanUnwrappingReduceGroupOperator<?, ?, ?> groupNode = (PlanUnwrappingReduceGroupOperator<?, ?, ?>) javaOp;
			type = groupNode.getInput().getOperatorInfo().getOutputType();
		}
		
		// the serializer is always required
		channel.setSerializer(createSerializer(type));
			
		// parameterize the ship strategy
		if (channel.getShipStrategy().requiresComparator()) {
			channel.setShipStrategyComparator(createComparator(type, channel.getShipStrategyKeys(), 
				getSortOrders(channel.getShipStrategyKeys(), channel.getShipStrategySortOrder())));
		}
			
		// parameterize the local strategy
		if (channel.getLocalStrategy().requiresComparator()) {
			channel.setLocalStrategyComparator(createComparator(type, channel.getLocalStrategyKeys(),
				getSortOrders(channel.getLocalStrategyKeys(), channel.getLocalStrategySortOrder())));
		}
		
		// descend to the channel's source
		traverse(channel.getSource());
	}
	
	
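	/**
	 * Extracts the output type information from the data source operator behind the given
	 * source plan node.
	 */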
	@SuppressWarnings("unchecked")
	private static <T> TypeInformation<T> getTypeInfoFromSource(SourcePlanNode node) {
		Operator<?> op = node.getOptimizerNode().getOperator();
		
		if (op instanceof GenericDataSourceBase) {
			return ((GenericDataSourceBase<T, ?>) op).getOperatorInfo().getOutputType();
		} else {
			throw new RuntimeException("Wrong operator type found in post pass.");
		}
	}
	
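	/**
	 * Creates a serializer factory for the given type, based on the serializer that the
	 * type information produces for the current execution config.
	 */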
	private <T> TypeSerializerFactory<?> createSerializer(TypeInformation<T> typeInfo) {
		TypeSerializer<T> serializer = typeInfo.createSerializer(executionConfig);

		return new RuntimeSerializerFactory<T>(serializer, typeInfo.getTypeClass());
	}
	
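	/**
	 * Creates a comparator factory for the given type on the given key fields. Composite
	 * types are compared on the selected fields, atomic types on the value itself.
	 */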
	@SuppressWarnings("unchecked")
	private <T> TypeComparatorFactory<?> createComparator(TypeInformation<T> typeInfo, FieldList keys, boolean[] sortOrder) {
		
		TypeComparator<T> comparator;
		if (typeInfo instanceof CompositeType) {
			comparator = ((CompositeType<T>) typeInfo).createComparator(keys.toArray(), sortOrder, 0, executionConfig);
		}
		else if (typeInfo instanceof AtomicType) {
			// handle grouping of atomic types
			comparator = ((AtomicType<T>) typeInfo).createComparator(sortOrder[0], executionConfig);
		}
		else {
			throw new RuntimeException("Unrecognized type: " + typeInfo);
		}

		return new RuntimeComparatorFactory<T>(comparator);
	}
	
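	/**
	 * Creates a pair comparator factory for two tuple inputs. The concrete type information
	 * is currently not needed, since the runtime pair comparator factory is generic.
	 */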
	private static <T1 extends Tuple, T2 extends Tuple> TypePairComparatorFactory<T1,T2> createPairComparator(TypeInformation<?> typeInfo1, TypeInformation<?> typeInfo2) {
		return new RuntimePairComparatorFactory<T1,T2>();
	}
	
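	/**
	 * Returns the given sort orders, or all-ascending orders if none are specified.
	 */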
	private static boolean[] getSortOrders(FieldList keys, boolean[] orders) {
		if (orders == null) {
			orders = new boolean[keys.size()];
			Arrays.fill(orders, true);
		}
		return orders;
	}
}