/*
 * Copyright (c) 2018 Villu Ruusmann
 *
 * This file is part of JPMML-R
 *
 * JPMML-R is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-R is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-R.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;

public class RPartConverter extends TreeModelConverter<RGenericVector> {

	private int useSurrogate = 0;

	private Formula formula = null;


	public RPartConverter(RGenericVector rpart){
		super(rpart);

		RGenericVector control = rpart.getGenericElement("control");

		RNumberVector<?> useSurrogate = control.getNumericElement("usesurrogate");

		this.useSurrogate = ValueUtil.asInt(useSurrogate.asScalar());

		switch(this.useSurrogate){
			case 0:
			case 1:
			case 2:
				break;
			default:
				throw new IllegalArgumentException();
		}
	}

	public boolean hasScoreDistribution(){
		return true;
	}

	@Override
	public void encodeSchema(RExpEncoder encoder){
		RGenericVector rpart = getObject();

		RGenericVector frame = rpart.getGenericElement("frame");
		RExp terms = rpart.getElement("terms");

		RGenericVector xlevels = rpart.getGenericAttribute("xlevels", false);
		RStringVector ylevels = rpart.getStringAttribute("ylevels", false);

		RIntegerVector var = frame.getFactorElement("var");

		FormulaContext context = new XLevelsFormulaContext(xlevels);

		Formula formula = FormulaUtil.createFormula(terms, context, encoder);

		FormulaUtil.setLabel(formula, terms, ylevels, encoder);

		List<String> names = FormulaUtil.removeSpecialSymbol(RExpUtil.getFactorLevels(var), "<leaf>", 0);

		FormulaUtil.addFeatures(formula, names, false, encoder);

		this.formula = formula;
	}

	@Override
	public TreeModel encodeModel(Schema schema){
		RGenericVector rpart = getObject();

		RGenericVector frame = rpart.getGenericElement("frame");
		RStringVector method = rpart.getStringElement("method");
		RNumberVector<?> splits = rpart.getNumericElement("splits");
		RIntegerVector csplit = rpart.getIntegerElement("csplit", false);

		RIntegerVector var = frame.getIntegerElement("var");
		RIntegerVector n = frame.getIntegerElement("n");
		RIntegerVector ncompete = frame.getIntegerElement("ncompete");
		RIntegerVector nsurrogate = frame.getIntegerElement("nsurrogate");

		RIntegerVector rowNames = frame.getIntegerAttribute("row.names");

		if((rowNames.getValues()).indexOf(Integer.MIN_VALUE) > -1){
			throw new IllegalArgumentException();
		}

		int[][] splitInfo = new int[1 + rowNames.size()][3];

		for(int offset = 0; offset < rowNames.size(); offset++){
			splitInfo[offset][1] = ncompete.getValue(offset);
			splitInfo[offset][2] = nsurrogate.getValue(offset);

			splitInfo[offset + 1][0] = splitInfo[offset][0] + splitInfo[offset][1] + splitInfo[offset][2] + (var.getValue(offset) != 1 ? 1 : 0);
		}

		switch(method.asScalar()){
			case "anova":
				return encodeRegression(frame, rowNames, var, n, splitInfo, splits, csplit, schema);
			case "class":
				return encodeClassification(frame, rowNames, var, n, splitInfo, splits, csplit, schema);
			default:
				throw new IllegalArgumentException();
		}
	}

	private TreeModel encodeRegression(RGenericVector frame, RIntegerVector rowNames, RIntegerVector var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema){
		RNumberVector<?> yval = frame.getNumericElement("yval");

		ScoreEncoder scoreEncoder = new ScoreEncoder(){

			@Override
			public Node encode(Node node, int offset){
				Number score = yval.getValue(offset);
				Number recordCount = n.getValue(offset);

				node
					.setScore(score)
					.setRecordCount(recordCount);

				return node;
			}
		};

		Node root = encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

		TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root);

		return configureTreeModel(treeModel);
	}

	private TreeModel encodeClassification(RGenericVector frame, RIntegerVector rowNames, RIntegerVector var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema){
		RDoubleVector yval2 = frame.getDoubleElement("yval2");

		CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

		List<?> categories = categoricalLabel.getValues();

		boolean hasScoreDistribution = hasScoreDistribution();

		ScoreEncoder scoreEncoder = new ScoreEncoder(){

			private List<Integer> classes = null;

			private List<List<? extends Number>> recordCounts = null;


			{
				int rows = rowNames.size();
				int columns = 1 + (2 * categories.size()) + 1;

				List<Integer> classes = ValueUtil.asIntegers(FortranMatrixUtil.getColumn(yval2.getValues(), rows, columns, 0));

				this.classes = new ArrayList<>(classes);

				if(hasScoreDistribution){
					this.recordCounts = new ArrayList<>();

					for(int i = 0; i < categories.size(); i++){
						List<? extends Number> recordCounts = FortranMatrixUtil.getColumn(yval2.getValues(), rows, columns, 1 + i);

						this.recordCounts.add(new ArrayList<>(recordCounts));
					}
				}
			}

			@Override
			public Node encode(Node node, int offset){
				Object score = categories.get(this.classes.get(offset) - 1);
				Integer recordCount = n.getValue(offset);

				node
					.setScore(score)
					.setRecordCount(recordCount);

				if(hasScoreDistribution){
					node = new ClassifierNode(node);

					List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();

					for(int i = 0; i < categories.size(); i++){
						List<? extends Number> recordCounts = this.recordCounts.get(i);

						ScoreDistribution scoreDistribution = new ScoreDistribution()
							.setValue(categories.get(i))
							.setRecordCount(recordCounts.get(offset));

						scoreDistributions.add(scoreDistribution);
					}
				}

				return node;
			}
		};

		Node root = encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

		TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()), root);

		if(hasScoreDistribution){
			treeModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
		}

		return configureTreeModel(treeModel);
	}

	private TreeModel configureTreeModel(TreeModel treeModel){
		TreeModel.NoTrueChildStrategy noTrueChildStrategy = TreeModel.NoTrueChildStrategy.RETURN_LAST_PREDICTION;
		TreeModel.MissingValueStrategy missingValueStrategy;

		switch(this.useSurrogate){
			case 0:
				missingValueStrategy = TreeModel.MissingValueStrategy.NULL_PREDICTION; // XXX
				break;
			case 1:
				missingValueStrategy = TreeModel.MissingValueStrategy.LAST_PREDICTION;
				break;
			case 2:
				missingValueStrategy = null;
				break;
			default:
				throw new IllegalArgumentException();
		}

		treeModel
			.setNoTrueChildStrategy(noTrueChildStrategy)
			.setMissingValueStrategy(missingValueStrategy);

		return treeModel;
	}

	private Node encodeNode(Predicate predicate, int rowName, RIntegerVector rowNames, RIntegerVector var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, ScoreEncoder scoreEncoder, Schema schema){
		int offset = getIndex(rowNames, rowName);

		Integer id = Integer.valueOf(rowName);

		int splitVar = var.getValue(offset) - 1;
		if(splitVar == 0){
			Node result = new CountingLeafNode(null, predicate)
				.setId(id);

			return scoreEncoder.encode(result, offset);
		}

		int leftRowName = rowName * 2;
		int rightRowName = (rowName * 2) + 1;

		Integer majorityDir = null;

		if(this.useSurrogate == 2){
			int leftOffset = getIndex(rowNames, leftRowName);
			int rightOffset = getIndex(rowNames, rightRowName);

			majorityDir = Double.compare(n.getValue(leftOffset), n.getValue(rightOffset));
		}

		Feature feature = schema.getFeature(splitVar - 1);

		int splitOffset = splitInfo[offset][0];

		int splitNumCompete = splitInfo[offset][1];
		int splitNumSurrogate = splitInfo[offset][2];

		List<Predicate> predicates = encodePredicates(feature, splitOffset, splits, csplit);

		Predicate leftPredicate = predicates.get(0);
		Predicate rightPredicate = predicates.get(1);

		if(this.useSurrogate > 0 && splitNumSurrogate > 0){
			CompoundPredicate leftCompoundPredicate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null)
				.addPredicates(leftPredicate);

			CompoundPredicate rightCompoundPredicate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null)
				.addPredicates(rightPredicate);

			RStringVector splitRowNames = splits.dimnames(0);

			for(int i = 0; i < splitNumSurrogate; i++){
				int surrogateSplitOffset = (splitOffset + 1) + splitNumCompete + i;

				feature = getFeature(FieldName.create(splitRowNames.getValue(surrogateSplitOffset)));

				predicates = encodePredicates(feature, surrogateSplitOffset, splits, csplit);

				leftCompoundPredicate.addPredicates(predicates.get(0));
				rightCompoundPredicate.addPredicates(predicates.get(1));
			}

			leftPredicate = leftCompoundPredicate;
			rightPredicate = rightCompoundPredicate;
		}

		Node leftChild = encodeNode(leftPredicate, leftRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);
		Node rightChild = encodeNode(rightPredicate, rightRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

		if(this.useSurrogate == 2){

			if(majorityDir < 0){
				makeDefault(rightChild);
			} else

			if(majorityDir > 0){
				Node tmp = leftChild;

				makeDefault(leftChild);

				leftChild = rightChild;
				rightChild = tmp;
			}
		}

		Node result = new CountingBranchNode(null, predicate)
			.setId(id)
			.addNodes(leftChild, rightChild);

		return scoreEncoder.encode(result, offset);
	}

	private List<Predicate> encodePredicates(Feature feature, int splitOffset, RNumberVector<?> splits, RIntegerVector csplit){
		Predicate leftPredicate;
		Predicate rightPredicate;

		RIntegerVector splitsDim = splits.dim();

		int splitRows = splitsDim.getValue(0);
		int splitColumns = splitsDim.getValue(1);

		List<? extends Number> ncat = FortranMatrixUtil.getColumn(splits.getValues(), splitRows, splitColumns, 1);
		List<? extends Number> index = FortranMatrixUtil.getColumn(splits.getValues(), splitRows, splitColumns, 3);

		int splitType = ValueUtil.asInt(ncat.get(splitOffset));

		Number splitValue = index.get(splitOffset);

		if(Math.abs(splitType) == 1){
			SimplePredicate.Operator leftOperator;
			SimplePredicate.Operator rightOperator;

			if(splitType == -1){
				leftOperator = SimplePredicate.Operator.LESS_THAN;
				rightOperator = SimplePredicate.Operator.GREATER_OR_EQUAL;
			} else

			{
				leftOperator = SimplePredicate.Operator.GREATER_OR_EQUAL;
				rightOperator = SimplePredicate.Operator.LESS_THAN;
			}

			leftPredicate = createSimplePredicate(feature, leftOperator, splitValue);
			rightPredicate = createSimplePredicate(feature, rightOperator, splitValue);
		} else

		{
			CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

			RIntegerVector csplitDim = csplit.dim();

			int csplitRows = csplitDim.getValue(0);
			int csplitColumns = csplitDim.getValue(1);

			List<Integer> csplitRow = FortranMatrixUtil.getRow(csplit.getValues(), csplitRows, csplitColumns, ValueUtil.asInt(splitValue) - 1);

			List<?> values = categoricalFeature.getValues();

			leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, csplitRow, 1));
			rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, csplitRow, 3));
		}

		return Arrays.asList(leftPredicate, rightPredicate);
	}

	private void makeDefault(Node node){
		Predicate predicate = node.getPredicate();

		CompoundPredicate compoundPredicate;

		if(predicate instanceof CompoundPredicate){
			compoundPredicate = (CompoundPredicate)predicate;
		} else

		{
			compoundPredicate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null)
				.addPredicates(predicate);

			node.setPredicate(compoundPredicate);
		}

		compoundPredicate.addPredicates(True.INSTANCE);
	}

	private Feature getFeature(FieldName name){
		return this.formula.resolveFeature(name);
	}

	static
	private int getIndex(RIntegerVector rowNames, int rowName){
		int index = rowNames.indexOf(rowName);
		if(index < 0){
			throw new IllegalArgumentException();
		}

		return index;
	}

	static
	private <E> List<E> selectValues(List<E> values, List<Integer> valueFlags, int flag){
		List<E> result = new ArrayList<>(values.size());

		for(int i = 0; i < values.size(); i++){
			E value = values.get(i);
			Integer valueFlag = valueFlags.get(i);

			if(valueFlag == flag){
				result.add(value);
			}
		}

		return result;
	}

	static
	private interface ScoreEncoder {

		Node encode(Node node, int offset);
	}
}