/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator.spark;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.feature.ColumnPruner;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.ResultFeature;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.OutputField;
import org.jpmml.evaluator.PMMLAttributes;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.TargetField;
import scala.collection.immutable.Set;

public class TransformerBuilder {

	private Evaluator evaluator = null;

	private List<ColumnProducer<? extends ResultField>> columnProducers = new ArrayList<>();

	private boolean exploded = false;


	public TransformerBuilder(Evaluator evaluator){
		setEvaluator(evaluator);
	}

	/**
	 * <p>
	 * Appends all target fields.
	 * </p>
	 *
	 * @see Evaluator#getTargetFields()
	 */
	public TransformerBuilder withTargetCols(){
		Evaluator evaluator = getEvaluator();

		List<TargetField> targetFields = evaluator.getTargetFields();
		for(TargetField targetField : targetFields){
			this.columnProducers.add(new TargetColumnProducer(targetField, null));
		}

		return this;
	}

	/**
	 * <p>
	 * Appends all output fields.
	 * </p>
	 *
	 * @see Evaluator#getOutputFields()
	 */
	public TransformerBuilder withOutputCols(){
		Evaluator evaluator = getEvaluator();

		List<OutputField> outputFields = evaluator.getOutputFields();
		for(OutputField outputField : outputFields){
			this.columnProducers.add(new OutputColumnProducer(outputField, null));
		}

		return this;
	}

	/**
	 * <p>
	 * Appends the sole target field of a regression or classification model.
	 * </p>
	 *
	 * @param columnName The name of the target column.
	 */
	public TransformerBuilder withLabelCol(String columnName){
		Evaluator evaluator = getEvaluator();

		TargetField targetField = getTargetField(evaluator);

		this.columnProducers.add(new TargetColumnProducer(targetField, columnName));

		return this;
	}

	public TransformerBuilder withProbabilityCol(String columnName){
		return withProbabilityCol(columnName, null);
	}

	/**
	 * <p>
	 * Appends the probability distribution associated with the sole target field of a classification model.
	 * </p>
	 *
	 * @param columnName The name of the probability column.
	 * @param labels The ordering of class label elements in the vector.
	 */
	public TransformerBuilder withProbabilityCol(String columnName, List<String> labels){
		Evaluator evaluator = getEvaluator();

		TargetField targetField = getTargetField(evaluator);

		List<OutputField> probabilityOutputFields = getProbabilityFields(evaluator, targetField);

		List<String> targetCategories = probabilityOutputFields.stream()
			.map(probabilityOutputField -> {
				org.dmg.pmml.OutputField pmmlOutputField = probabilityOutputField.getField();

				String value = pmmlOutputField.getValue();
				if(value == null){
					throw new MissingAttributeException(pmmlOutputField, PMMLAttributes.OUTPUTFIELD_VALUE);
				}

				return value;
			})
			.collect(Collectors.toList());

		if((labels != null) && (labels.size() != targetCategories.size() || !labels.containsAll(targetCategories))){
			throw new IllegalArgumentException("Model has an incompatible set of probability-type output fields (expected " + labels + ", got " + targetCategories + ")");
		}

		this.columnProducers.add(new ProbabilityColumnProducer(targetField, columnName, labels != null ? labels : targetCategories));

		return this;
	}

	public TransformerBuilder exploded(boolean exploded){
		this.exploded = exploded;

		return this;
	}

	public Transformer build(){
		Evaluator evaluator = getEvaluator();

		PMMLTransformer pmmlTransformer = new PMMLTransformer(evaluator, this.columnProducers);

		if(this.exploded){
			ColumnExploder columnExploder = new ColumnExploder(pmmlTransformer.getOutputCol());

			ColumnPruner columnPruner = new ColumnPruner(new Set.Set1<>(pmmlTransformer.getOutputCol()));

			PipelineModel pipelineModel = new PipelineModel(null, new Transformer[]{pmmlTransformer, columnExploder, columnPruner});

			return pipelineModel;
		}

		return pmmlTransformer;
	}

	private Evaluator getEvaluator(){
		return this.evaluator;
	}

	private void setEvaluator(Evaluator evaluator){
		this.evaluator = evaluator;
	}

	static
	private TargetField getTargetField(Evaluator evaluator){
		List<TargetField> targetFields = evaluator.getTargetFields();

		if(targetFields.size() < 1){
			throw new IllegalArgumentException("Model does not have a target field");
		} else

		if(targetFields.size() > 1){
			throw new IllegalArgumentException("Model has multiple target fields (" + targetFields + ")");
		}

		return targetFields.get(0);
	}

	static
	private List<OutputField> getProbabilityFields(Evaluator evaluator, TargetField targetField){
		List<OutputField> outputFields = evaluator.getOutputFields();

		Predicate<OutputField> predicate = new Predicate<OutputField>(){

			@Override
			public boolean test(OutputField outputField){
				org.dmg.pmml.OutputField pmmlOutputField = outputField.getField();

				ResultFeature resultFeature = pmmlOutputField.getResultFeature();
				switch(resultFeature){
					case PROBABILITY:
						FieldName targetFieldName = pmmlOutputField.getTargetField();

						return Objects.equals(targetFieldName, null) || Objects.equals(targetFieldName, targetField.getName());
					default:
						return false;
				}
			}
		};

		List<OutputField> probabilityOutputFields = outputFields.stream()
			.filter(predicate)
			.collect(Collectors.toList());

		if(probabilityOutputFields.size() < 1){
			throw new IllegalArgumentException("Model does not have probability-type output fields");
		}

		return probabilityOutputFields;
	}
}