/*
 * Modifications Copyright 2019 Graz University of Technology
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.tugraz.sysds.runtime.instructions.spark;

import org.apache.commons.lang.ArrayUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.tugraz.sysds.lops.PickByCount.OperationTypes;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.context.ExecutionContext;
import org.tugraz.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.tugraz.sysds.runtime.instructions.InstructionUtils;
import org.tugraz.sysds.runtime.instructions.cp.CPOperand;
import org.tugraz.sysds.runtime.instructions.cp.DoubleObject;
import org.tugraz.sysds.runtime.instructions.cp.ScalarObject;
import org.tugraz.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import org.tugraz.sysds.runtime.matrix.data.MatrixIndexes;
import org.tugraz.sysds.runtime.matrix.operators.Operator;
import org.tugraz.sysds.runtime.meta.DataCharacteristics;
import org.tugraz.sysds.runtime.util.DataConverter;
import org.tugraz.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.IntStream;

public class QuantilePickSPInstruction extends BinarySPInstruction {
	private OperationTypes _type = null;

	/**
	 * Creates a quantile-pick instruction with a single input operand.
	 * Delegates to the two-input constructor, passing {@code null} as
	 * the second input operand.
	 *
	 * @param op operator (forwarded unchanged)
	 * @param in first (and only) input operand
	 * @param out output operand
	 * @param type pick operation type
	 * @param inmem in-memory flag (forwarded unchanged)
	 * @param opcode instruction opcode
	 * @param istr full instruction string
	 */
	private QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand out, OperationTypes type, boolean inmem,
			String opcode, String istr) {
		this(
			op,
			in,
			null, //no second input
			out,
			type,
			inmem,
			opcode,
			istr);
	}

	/**
	 * Creates a quantile-pick instruction with two input operands and
	 * remembers the requested pick operation type.
	 *
	 * @param op operator passed to the super constructor
	 * @param in first input operand
	 * @param in2 second input operand (may be null)
	 * @param out output operand
	 * @param type pick operation type stored in this instruction
	 * @param inmem in-memory flag; accepted for signature compatibility but not stored here
	 * @param opcode instruction opcode
	 * @param istr full instruction string
	 */
	private QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand in2, CPOperand out, OperationTypes type,
			boolean inmem, String opcode, String istr) {
		super(
			SPType.QPick, op,
			in, in2, out,
			opcode, istr);
		this._type = type;
	}

	/**
	 * Parses the string representation of a qpick instruction.
	 * <p>
	 * Three layouts are recognized, distinguished by the number of parts
	 * (including the opcode):
	 * <ul>
	 * <li>4 parts: {@code opcode, in1, in2, out} - always treated as IQM</li>
	 * <li>5 parts: {@code opcode, in1, out, type, inmem}</li>
	 * <li>6 parts: {@code opcode, in1, in2, out, type, inmem}</li>
	 * </ul>
	 *
	 * @param str instruction string
	 * @return the parsed instruction, or {@code null} if the number of parts
	 *         matches none of the layouts above
	 * @throws DMLRuntimeException if the opcode is not {@code qpick}
	 */
	public static QuantilePickSPInstruction parseInstruction ( String str ) {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		String opcode = parts[0];
		
		//sanity check opcode (message names this class, not its CP counterpart)
		if ( !opcode.equalsIgnoreCase("qpick") ) {
			throw new DMLRuntimeException("Unknown opcode while parsing a QuantilePickSPInstruction: " + str);
		}
		
		//instruction parsing
		if( parts.length == 4 ) {
			//instructions of length 4 originate from unary - mr-iqm
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand out = new CPOperand(parts[3]);
			OperationTypes ptype = OperationTypes.IQM;
			return new QuantilePickSPInstruction(null, in1, in2, out, ptype, false, opcode, str);
		}
		else if( parts.length == 5 ) {
			//single input: in1, out, type, inmem
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand out = new CPOperand(parts[2]);
			OperationTypes ptype = OperationTypes.valueOf(parts[3]);
			boolean inmem = Boolean.parseBoolean(parts[4]);
			return new QuantilePickSPInstruction(null, in1, out, ptype, inmem, opcode, str);
		}
		else if( parts.length == 6 ) {
			//two inputs: in1, in2, out, type, inmem
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand out = new CPOperand(parts[3]);
			OperationTypes ptype = OperationTypes.valueOf(parts[4]);
			boolean inmem = Boolean.parseBoolean(parts[5]);
			return new QuantilePickSPInstruction(null, in1, in2, out, ptype, inmem, opcode, str);
		}
		
		//unknown layout: kept as null for backward compatibility with existing callers
		return null;
	}
	
	/**
	 * Executes the quantile pick on the rdd bound to the first input.
	 * <ul>
	 * <li>VALUEPICK: picks the value at the quantile given by the second
	 *     input (scalar output), or one value per quantile if the second
	 *     input is a matrix (column-vector output)</li>
	 * <li>MEDIAN: picks the value at quantile 0.5 (scalar output)</li>
	 * <li>IQM: sums the values between the 25% and 75% keys and applies
	 *     {@code MatrixBlock.computeIQMCorrection} (scalar output)</li>
	 * </ul>
	 *
	 * @param ec execution context, expected to be a {@link SparkExecutionContext}
	 * @throws DMLRuntimeException for any other operation type
	 */
	@Override
	public void processInstruction(ExecutionContext ec) {
		SparkExecutionContext sec = (SparkExecutionContext)ec;
		
		//obtain input rdd handle and its data characteristics
		final String inName = input1.getName();
		JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( inName );
		DataCharacteristics mc = sec.getDataCharacteristics( inName );
		
		//NOTE: there is no difference between inmem/mr pick (see related cp instruction);
		//the only distinction is with/without weights (in contrast to the cp instruction,
		//the case without weights does not materialize weights of 1)
		switch( _type ) {
			case VALUEPICK: {
				if( !input2.isScalar() ) {
					//vector of quantiles: summary layout is [sum, keys, portions, values]
					double[] summary = getWeightedQuantileSummary(in, mc,
						DataConverter.convertToDoubleVector(
							ec.getMatrixInput(input2.getName())));
					ec.releaseMatrixInput(input2.getName());
					final int numQ = summary.length/3;
					final int valueStart = 2*numQ + 1;
					MatrixBlock out = new MatrixBlock(numQ,1,false);
					final int nrow = out.getNumRows();
					for( int i=0; i<nrow; i++ )
						out.quickSetValue(i, 0, summary[valueStart+i]);
					ec.setMatrixOutput(output.getName(), out);
				}
				else {
					//single quantile: value is stored at position 3
					ScalarObject q = ec.getScalarInput(input2);
					double[] quantiles = new double[]{q.getDoubleValue()};
					double[] summary = getWeightedQuantileSummary(in, mc, quantiles);
					ec.setScalarOutput(output.getName(), new DoubleObject(summary[3]));
				}
				break;
			}
			
			case MEDIAN: {
				double[] half = new double[]{0.5};
				double[] summary = getWeightedQuantileSummary(in, mc, half);
				ec.setScalarOutput(output.getName(), new DoubleObject(summary[3]));
				break;
			}
			
			case IQM: {
				//summary layout: [sum, key25, key75, part25, part75, val25, val75]
				double[] summary = getWeightedQuantileSummary(in, mc, new double[]{0.25,0.75});
				final long lower = (long)Math.ceil(summary[1]) + 1;
				final long upper = (long)Math.ceil(summary[2]);
				final int blen = mc.getBlocksize();
				
				//restrict to relevant blocks and compute per-block partial sums
				JavaPairRDD<MatrixIndexes,MatrixBlock> partials = in
					.filter(new FilterFunction(lower, upper, blen))
					.mapToPair(new ExtractAndSumFunction(lower, upper, mc.getBlocksize()));
				double total = RDDAggregateUtils.sumStable(partials).getValue(0, 0);
				double iqm = MatrixBlock.computeIQMCorrection(
					total, summary[0], summary[3], summary[5], summary[4], summary[6]);
				ec.setScalarOutput(output.getName(), new DoubleObject(iqm));
				break;
			}
		
			default:
				throw new DMLRuntimeException("Unsupported qpick operation type: "+_type);
		}
	}
	
	/**
	 * Get a summary of weighted quantiles in the following form:
	 * sum of weights, (keys of quantiles), (portions of quantiles), (values of quantiles)
	 * <p>
	 * The returned array has length {@code 3*quantiles.length+1}: position 0
	 * holds the sum of weights (or the number of rows if unweighted), followed
	 * by three consecutive sections of {@code quantiles.length} entries each.
	 * <p>
	 * The input is treated as weighted iff it has exactly two columns.
	 * 
	 * @param w rdd containing values and optionally weights, sorted by value
	 * @param mc matrix characteristics
	 * @param quantiles one or more quantiles between 0 and 1.
	 * @return a summary of weighted quantiles
	 */
	private static double[] getWeightedQuantileSummary(JavaPairRDD<MatrixIndexes,MatrixBlock> w, DataCharacteristics mc, double[] quantiles)
	{
		double[] ret = new double[3*quantiles.length + 1];
		if( mc.getCols()==2 ) //weighted 
		{
			//sort blocks (values sorted but blocks and partitions are not)
			w = w.sortByKey();
			
			//compute cumsum weights per partition
			//with assumption that partition aggregates fit into memory
			List<Tuple2<Integer,Double>> partWeights = w
				.mapPartitionsWithIndex(new SumWeightsFunction(), false).collect();
			
			//compute sum of weights
			ret[0] = partWeights.stream().mapToDouble(p -> p._2()).sum();
			
			//compute total cumsum and determine partitions
			double[] qdKeys = new double[quantiles.length];
			long[] qiKeys = new long[quantiles.length];
			int[] partitionIDs = new int[quantiles.length];
			double[] offsets = new double[quantiles.length];
			//use -1 as "unassigned" marker because 0 is a valid partition index;
			//with 0 as marker, a quantile located in partition 0 would be
			//re-assigned to the subsequent partition with a wrong offset
			Arrays.fill(partitionIDs, -1);
			for( int i=0; i<quantiles.length; i++ ) {
				qdKeys[i] = quantiles[i]*ret[0];
				qiKeys[i] = (long)Math.ceil(qdKeys[i]);
			}
			double cumSum = 0;
			for( Tuple2<Integer,Double> psum : partWeights ) {
				double tmp = cumSum + psum._2();
				for(int i=0; i<quantiles.length; i++)
					if( partitionIDs[i] < 0 && tmp >= qiKeys[i] ) {
						//first partition whose cumulative weight reaches the key
						partitionIDs[i] = psum._1();
						offsets[i] = cumSum;
					}
				cumSum = tmp;
			}
			
			//get keys and values for quantile cutoffs 
			List<Tuple2<Integer,double[]>> qVals = w
				.mapPartitionsWithIndex(new ExtractWeightedQuantileFunction(
					mc, qdKeys, qiKeys, partitionIDs, offsets), false).collect();
			for( Tuple2<Integer,double[]> qVal : qVals ) {
				ret[qVal._1()+1] = qVal._2()[0];
				ret[qVal._1()+quantiles.length+1] = qVal._2()[1];
				ret[qVal._1()+2*quantiles.length+1] = qVal._2()[2];
			}
		}
		else {
			//unweighted: every row counts once, so the keys follow directly
			//from the number of rows and values are fetched via block lookup
			ret[0] = mc.getRows();
			for( int i=0; i<quantiles.length; i++ ){
				ret[i+1] = quantiles[i] * mc.getRows();
				ret[i+quantiles.length+1] = Math.ceil(ret[i+1])-ret[i+1];
				ret[i+2*quantiles.length+1] = lookupKey(w, 
					(long)Math.ceil(ret[i+1]), mc.getBlocksize());
			}
		}
		
		return ret;
	}

	/**
	 * Looks up the value in the first column at the given row key.
	 * The key is translated into a block row index and a position inside
	 * that block via {@code UtilFunctions}; the block is then fetched from
	 * the rdd with column block index 1.
	 *
	 * @param in rdd of matrix blocks
	 * @param key row key to look up
	 * @param blen block size
	 * @return value at the key's position in column 0 of the fetched block
	 * @throws DMLRuntimeException if no block is found or the position
	 *         lies beyond the rows of the fetched block
	 */
	private static double lookupKey(JavaPairRDD<MatrixIndexes,MatrixBlock> in, long key, int blen) {
		final long blockRow = UtilFunctions.computeBlockIndex(key, blen);
		final long cell = UtilFunctions.computeCellInBlock(key, blen);
		
		//fetch the target block
		List<MatrixBlock> hits = in.lookup(new MatrixIndexes(blockRow,1));
		if( hits.isEmpty() )
			throw new DMLRuntimeException("Invalid key lookup in empty list.");
		
		//guard against positions outside of the block
		MatrixBlock block = hits.get(0);
		if( cell >= block.getNumRows() ) {
			throw new DMLRuntimeException("Invalid key lookup for " +
				cell + " in block of size " + block.getNumRows()+"x"+block.getNumColumns());
		}
		return block.quickGetValue((int)cell, 0);
	}
	
	/**
	 * Keeps only blocks whose row block index lies within the (inclusive)
	 * range of block indexes derived from the two given keys.
	 */
	private static class FilterFunction implements Function<Tuple2<MatrixIndexes,MatrixBlock>, Boolean> 
	{
		private static final long serialVersionUID = -8249102381116157388L;

		//block index boundaries (both inclusive)
		private final long _minRowIndex;
		private final long _maxRowIndex;
		
		/**
		 * @param key25 lower boundary key
		 * @param key75 upper boundary key
		 * @param blen block size used to map keys to block indexes
		 */
		public FilterFunction(long key25, long key75, int blen) {
			this._minRowIndex = UtilFunctions.computeBlockIndex(key25, blen);
			this._maxRowIndex = UtilFunctions.computeBlockIndex(key75, blen);
		}

		@Override
		public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
			throws Exception 
		{
			final long blockRow = arg0._1().getRowIndex();
			if( blockRow < _minRowIndex )
				return false;
			return blockRow <= _maxRowIndex;
		}
	}

	/**
	 * Maps every block to a 1x2 block, keyed by block index (1,1), whose
	 * cell (0,0) holds the sum over the block's rows that fall within the
	 * given key range. For single-column blocks the values are summed;
	 * otherwise the products of column 0 and column 1 are summed.
	 */
	private static class ExtractAndSumFunction implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock>,MatrixIndexes,MatrixBlock> 
	{
		private static final long serialVersionUID = -584044441055250489L;
		
		//block index and in-block position boundaries (all inclusive)
		private final long _minRowIndex;
		private final long _maxRowIndex;
		private final int _minPos;
		private final int _maxPos;
		
		/**
		 * @param key25 lower boundary key
		 * @param key75 upper boundary key
		 * @param blen block size used to map keys to block indexes and positions
		 */
		public ExtractAndSumFunction(long key25, long key75, int blen)
		{
			//lower boundary
			this._minRowIndex = UtilFunctions.computeBlockIndex(key25, blen);
			this._minPos = UtilFunctions.computeCellInBlock(key25, blen);
			//upper boundary
			this._maxRowIndex = UtilFunctions.computeBlockIndex(key75, blen);
			this._maxPos = UtilFunctions.computeCellInBlock(key75, blen);
		}
		
		@Override
		public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) 
			throws Exception 
		{
			final long blockRow = arg0._1().getRowIndex();
			final MatrixBlock block = arg0._2();
			
			//row range [from, to) within this block; only the boundary
			//blocks are cut, all inner blocks are used completely
			int from = 0;
			if( blockRow == _minRowIndex )
				from = _minPos;
			int to = block.getNumRows();
			if( blockRow == _maxRowIndex )
				to = _maxPos+1;
			
			double partial;
			if( block.getNumColumns()==1 )
				partial = sum(block, from, to);
			else
				partial = sumWeighted(block, from, to);
			
			//second column is left untouched here
			//NOTE(review): presumably reserved for the stable-sum correction - confirm
			MatrixBlock ret = new MatrixBlock(1,2,false);
			ret.setValue(0, 0, partial);
			return new Tuple2<>(new MatrixIndexes(1,1), ret);
		}
		
		/** Sums column 0 over rows [rl, ru). */
		private static double sum(MatrixBlock mb, int rl, int ru) {
			double acc = 0;
			int r = rl;
			while( r < ru ) {
				acc += mb.quickGetValue(r, 0);
				r++;
			}
			return acc;
		}
		
		/** Sums the products of column 0 and column 1 over rows [rl, ru). */
		private static double sumWeighted(MatrixBlock mb, int rl, int ru) {
			double acc = 0;
			int r = rl;
			while( r < ru ) {
				acc += mb.quickGetValue(r, 0)
					* mb.quickGetValue(r, 1);
				r++;
			}
			return acc;
		}
	}

	/**
	 * Partition-wise aggregation of weights: emits exactly one tuple
	 * (partition index, sum of {@code sumWeightForQuantile()} over all
	 * blocks of the partition) per partition.
	 */
	private static class SumWeightsFunction implements Function2<Integer,Iterator<Tuple2<MatrixIndexes,MatrixBlock>>,Iterator<Tuple2<Integer, Double>>> 
	{
		private static final long serialVersionUID = 7169831202450745373L;

		@Override
		public Iterator<Tuple2<Integer, Double>> call(Integer v1, Iterator<Tuple2<MatrixIndexes, MatrixBlock>> v2)
			throws Exception 
		{
			//accumulate block weights in iteration order
			double total = 0;
			for( ; v2.hasNext(); ) {
				MatrixBlock block = v2.next()._2();
				total += block.sumWeightForQuantile();
			}
			
			//single-element result: (partition index, partition weight)
			Tuple2<Integer,Double> agg = new Tuple2<>(v1, total);
			return Collections.singletonList(agg).iterator();
		}
	}
	
	/**
	 * Partition-wise extraction of weighted quantile cutoffs. For every
	 * quantile assigned to the processed partition, the rows are scanned
	 * while accumulating the weights (column 1) on top of the partition
	 * offset; once the accumulated weight reaches the quantile's integer
	 * key, a tuple (quantile index, {cell index, accumulated weight minus
	 * double key, value of column 0}) is emitted.
	 */
	private static class ExtractWeightedQuantileFunction implements Function2<Integer,Iterator<Tuple2<MatrixIndexes,MatrixBlock>>,Iterator<Tuple2<Integer, double[]>>> 
	{
		private static final long serialVersionUID = 4879975971050093739L;
		private final DataCharacteristics _mc;
		private final double[] _qdKeys;
		private final long[] _qiKeys;
		private final int[] _qPIDs;
		private final double[] _offsets;
		
		/**
		 * @param mc matrix characteristics (used for the block size)
		 * @param qdKeys quantile keys as doubles
		 * @param qiKeys quantile keys as longs (modified during processing)
		 * @param qPIDs partition index per quantile
		 * @param offsets cumulative weight offset per quantile
		 */
		public ExtractWeightedQuantileFunction(DataCharacteristics mc, double[] qdKeys, long[] qiKeys, int[] qPIDs, double[] offsets) {
			this._mc = mc;
			this._qdKeys = qdKeys;
			this._qiKeys = qiKeys;
			this._qPIDs = qPIDs;
			this._offsets = offsets;
		}

		@Override
		public Iterator<Tuple2<Integer, double[]>> call(Integer v1, Iterator<Tuple2<MatrixIndexes, MatrixBlock>> v2) 
			throws Exception 
		{
			final int pid = v1;
			
			//indexes of all quantiles assigned to this partition;
			//partitions without any quantile are skipped right away
			int[] active = IntStream.range(0, _qPIDs.length)
				.filter(k -> _qPIDs[k]==pid).toArray();
			if( active.length == 0 )
				return Collections.emptyIterator();
			
			//running sum of weights, starting at the offset of the first active quantile
			double running = _offsets[active[0]];
			
			//scan blocks and rows in iteration order
			List<Tuple2<Integer,double[]>> found = new ArrayList<>();
			while( v2.hasNext() ) {
				Tuple2<MatrixIndexes, MatrixBlock> entry = v2.next();
				MatrixIndexes ix = entry._1();
				MatrixBlock block = entry._2();
				for( int r=0; r<block.getNumRows(); r++ ) {
					double weight = block.quickGetValue(r, 1);
					double reached = running + weight;
					for( int q : active ) {
						if( reached >= _qiKeys[q] ) {
							long pos = UtilFunctions.computeCellIndex(ix.getRowIndex(), _mc.getBlocksize(), r);
							double posPart = reached - _qdKeys[q];
							found.add(new Tuple2<>(q, new double[]{pos, posPart, block.quickGetValue(r, 0)}));
							//mark as handled so the quantile is emitted only once
							_qiKeys[q] = Long.MAX_VALUE;
						}
					}
					running = reached;
				}
			}
			return found.iterator();
		}
	}
}