/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import com.google.common.collect.ArrayTable;
import com.google.common.collect.Lists;
import com.google.common.collect.Table;
import org.dmg.pmml.Aggregate;
import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldColumnPair;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.InvalidValueTreatmentMethod;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.TextIndex;
import org.dmg.pmml.TextIndex.CountHits;
import org.dmg.pmml.TextIndexNormalization;
import org.jpmml.evaluator.functions.EchoFunction;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

public class ExpressionUtilTest {

	/**
	 * Verifies that a {@link Constant} evaluates to its literal converted to the declared data type,
	 * and to {@code null} once it has been flagged as missing.
	 */
	@Test
	public void evaluateConstant(){
		// No literal at all: the string-typed constant is expected to evaluate to an empty string
		Constant blank = new Constant();
		blank.setDataType(DataType.STRING);

		assertEquals("", evaluate(blank));

		blank.setMissing(true);
		assertEquals(null, evaluate(blank));

		// The literal "3" interpreted as a string
		Constant asString = new Constant("3");
		asString.setDataType(DataType.STRING);

		assertEquals("3", evaluate(asString));

		asString.setMissing(true);
		assertEquals(null, evaluate(asString));

		// The literal "3" interpreted as an integer
		Constant asInteger = new Constant("3");
		asInteger.setDataType(DataType.INTEGER);

		assertEquals(3, evaluate(asInteger));

		asInteger.setMissing(true);
		assertEquals(null, evaluate(asInteger));

		// The literal "3" interpreted as a float
		Constant asFloat = new Constant("3");
		asFloat.setDataType(DataType.FLOAT);

		assertEquals(3f, evaluate(asFloat));

		asFloat.setMissing(true);
		assertEquals(null, evaluate(asFloat));

		// The literal "3" interpreted as a double
		Constant asDouble = new Constant("3");
		asDouble.setDataType(DataType.DOUBLE);

		assertEquals(3d, evaluate(asDouble));

		asDouble.setMissing(true);
		assertEquals(null, evaluate(asDouble));
	}

	/**
	 * Verifies that the literal {@code "NaN"} evaluates to {@link Double#NaN} when no data type is declared,
	 * and to {@link Float#NaN} after the data type has been set to float.
	 */
	@Test
	public void evaluateConstantNaN(){
		Constant notANumber = new Constant("NaN");

		Object untypedResult = evaluate(notANumber);
		assertEquals(Double.NaN, untypedResult);

		notANumber.setDataType(DataType.FLOAT);

		Object floatResult = evaluate(notANumber);
		assertEquals(Float.NaN, floatResult);
	}

	/**
	 * Verifies that a {@link FieldRef} passes the field value through,
	 * and substitutes the {@code mapMissingTo} replacement for a missing input once one is configured.
	 */
	@Test
	public void evaluateFieldRef(){
		FieldName field = FieldName.create("x");

		FieldRef reference = new FieldRef(field);

		// Without a replacement value, present and missing inputs are both returned as-is
		Object present = evaluate(reference, field, "3");
		Object absent = evaluate(reference, field, null);

		assertEquals("3", present);
		assertEquals(null, absent);

		reference.setMapMissingTo("Missing");

		Object replaced = evaluate(reference, field, null);

		assertEquals("Missing", replaced);
	}

	/**
	 * Verifies that a {@link NormContinuous} yields its {@code mapMissingTo} value for a missing input.
	 */
	@Test
	public void evaluateNormContinuous(){
		FieldName field = FieldName.create("x");

		NormContinuous normalization = new NormContinuous(field, null);
		normalization.setMapMissingTo(5d);

		Object replaced = evaluate(normalization, field, null);

		assertEquals(5d, replaced);
	}

	/**
	 * Verifies that a {@link NormDiscrete} yields {@code 1.0} when the input equals the reference value
	 * and {@code 0.0} otherwise, for string, integer and float inputs,
	 * and that a missing input is replaced with the {@code mapMissingTo} value.
	 */
	@Test
	public void evaluateNormDiscrete(){
		FieldName field = FieldName.create("x");

		Double hit = 1d;
		Double miss = 0d;

		// String input compared against the reference value "3"
		NormDiscrete againstString = new NormDiscrete(field, "3");

		assertEquals(hit, evaluate(againstString, field, "3"));
		assertEquals(miss, evaluate(againstString, field, "1"));

		againstString.setMapMissingTo(5d);
		assertEquals(5d, evaluate(againstString, field, null));

		// Integer input compared against the reference value "3"
		NormDiscrete againstInteger = new NormDiscrete(field, "3");

		assertEquals(hit, evaluate(againstInteger, field, 3));
		assertEquals(miss, evaluate(againstInteger, field, 1));

		// Float input compared against the reference value "3.0"
		NormDiscrete againstFloat = new NormDiscrete(field, "3.0");

		assertEquals(hit, evaluate(againstFloat, field, 3f));
		assertEquals(miss, evaluate(againstFloat, field, 1f));
	}

	/**
	 * Verifies the fallbacks of a {@link Discretize} that declares no bins:
	 * a missing input yields the {@code mapMissingTo} value (if set),
	 * and a present input yields the {@code defaultValue} (if set); otherwise the result is {@code null}.
	 */
	@Test
	public void evaluateDiscretize(){
		FieldName field = FieldName.create("x");

		Discretize binning = new Discretize(field);

		// Nothing configured yet
		assertEquals(null, evaluate(binning, field, null));

		binning.setMapMissingTo("Missing");

		// The replacement applies to the missing input only
		Object forMissing = evaluate(binning, field, null);
		Object forPresent = evaluate(binning, field, 3);

		assertEquals("Missing", forMissing);
		assertEquals(null, forPresent);

		binning.setDefaultValue("Default");

		assertEquals("Default", evaluate(binning, field, 3));
	}

	/**
	 * Verifies a {@link MapValues} lookup backed by a two-row inline table:
	 * matched inputs map to the output column, unmatched inputs yield the {@code defaultValue} (if set),
	 * and missing inputs yield the {@code mapMissingTo} value (if set).
	 */
	@Test
	public void evaluateMapValues(){
		FieldName field = FieldName.create("x");

		List<String> columns = Arrays.asList("data:input", "data:output");

		List<List<String>> lookupRows = Arrays.asList(
			Arrays.asList("0", "zero"),
			Arrays.asList("1", "one")
		);

		InlineTable lookupTable = createInlineTable(lookupRows, columns);

		MapValues lookup = new MapValues("data:output", null, lookupTable);
		lookup.addFieldColumnPairs(new FieldColumnPair(field, "data:input"));

		// Inputs that have a row in the table
		assertEquals("zero", evaluate(lookup, field, "0"));
		assertEquals("one", evaluate(lookup, field, "1"));

		// Unmatched and missing inputs, before any fallback has been configured
		assertEquals(null, evaluate(lookup, field, "3"));
		assertEquals(null, evaluate(lookup, field, null));

		lookup.setMapMissingTo("Missing");
		assertEquals("Missing", evaluate(lookup, field, null));

		lookup.setDefaultValue("Default");
		assertEquals("Default", evaluate(lookup, field, "3"));
	}

	/**
	 * Verifies term counting by {@link TextIndex} across four scenarios:
	 * a custom word separator pattern, increasing Levenshtein distance limits,
	 * the two {@link CountHits} modes, and case-insensitive matching.
	 */
	@Test
	public void evaluateTextIndex(){
		FieldName field = FieldName.create("x");

		// Scenario 1: whitespace and hyphen are both treated as word separators
		TextIndex separatorIndex = new TextIndex(field, new Constant("user friendly"));
		separatorIndex.setWordSeparatorCharacterRE("[\\s\\-]");

		assertEquals(null, evaluate(separatorIndex, field, null));

		assertEquals(1, evaluate(separatorIndex, field, "user friendly"));
		assertEquals(1, evaluate(separatorIndex, field, "user-friendly"));

		// Scenario 2: the hit count grows as the permitted edit distance grows
		TextIndex foxIndex = new TextIndex(field, new Constant("brown fox"));

		String foxText = "The quick browny foxy jumps over the lazy dog. The brown fox runs away and to be with another brown foxy.";

		foxIndex.setMaxLevenshteinDistance(0);
		assertEquals(1, evaluate(foxIndex, field, foxText));

		foxIndex.setMaxLevenshteinDistance(1);
		assertEquals(2, evaluate(foxIndex, field, foxText));

		foxIndex.setMaxLevenshteinDistance(2);
		assertEquals(3, evaluate(foxIndex, field, foxText));

		// Scenario 3: counting all hits versus best hits only
		TextIndex dogIndex = new TextIndex(field, new Constant("dog"));
		dogIndex.setMaxLevenshteinDistance(1);

		String dogText = "I have a doog. My dog is white. The doog is friendly.";

		dogIndex.setCountHits(CountHits.ALL_HITS);
		assertEquals(3, evaluate(dogIndex, field, dogText));

		dogIndex.setCountHits(CountHits.BEST_HITS);
		assertEquals(1, evaluate(dogIndex, field, dogText));

		// Scenario 4: case-insensitive matching, with and without a one-character tolerance
		TextIndex sunIndex = new TextIndex(field, new Constant("sun"));
		sunIndex.setCaseSensitive(false);

		String sunText = "The Sun was setting while the captain's son reached the bounty island, minutes after their ship had sunk to the bottom of the ocean.";

		sunIndex.setMaxLevenshteinDistance(0);
		assertEquals(1, evaluate(sunIndex, field, sunText));

		sunIndex.setMaxLevenshteinDistance(1);
		assertEquals(3, evaluate(sunIndex, field, sunText));
	}

	/**
	 * Verifies that a {@link TextIndex} with two chained {@link TextIndexNormalization} steps
	 * finds the term {@code ui_good} in a sentence that only matches after both steps have rewritten it.
	 */
	@Test
	public void evaluateTextIndexNormalization(){
		FieldName field = FieldName.create("x");

		// First step: uses the default column names of TextIndexNormalization
		TextIndexNormalization firstStep = new TextIndexNormalization();

		List<List<String>> firstStepRows = Arrays.asList(
			Arrays.asList("interfaces?", "interface", "true"),
			Arrays.asList("is|are|seem(ed|s?)|were", "be", "true"),
			Arrays.asList("user friendl(y|iness)", "user_friendly", "true")
		);

		firstStep.setInlineTable(createInlineTable(firstStepRows, firstStep));

		// Second step: custom input and output column names; these must be set before the table is built from them
		TextIndexNormalization secondStep = new TextIndexNormalization();
		secondStep.setInField("re");
		secondStep.setOutField("feature");

		List<List<String>> secondStepRows = Arrays.asList(
			Arrays.asList("interface be (user_friendly|well designed|excellent)", "ui_good", "true")
		);

		secondStep.setInlineTable(createInlineTable(secondStepRows, secondStep));

		TextIndex textIndex = new TextIndex(field, new Constant("ui_good"));
		textIndex.setLocalTermWeights(TextIndex.LocalTermWeights.BINARY);
		textIndex.setCaseSensitive(false);
		textIndex.addTextIndexNormalizations(firstStep, secondStep);

		String review = "Testing the app for a few days convinced me the interfaces are excellent!";

		assertEquals(1, evaluate(textIndex, field, review));
	}

	/**
	 * Verifies the fallback attributes of an {@link Apply} that divides the field by zero:
	 * the {@code defaultValue} and {@code mapMissingTo} replacements for a missing input,
	 * and the effect of each {@link InvalidValueTreatmentMethod} on the invalid division result.
	 */
	@Test
	public void evaluateApply(){
		FieldName field = FieldName.create("x");

		// x / 0
		Apply division = new Apply(PMMLFunctions.DIVIDE);
		division.addExpressions(new FieldRef(field), new Constant("0"));

		// Missing input, no fallbacks configured
		assertEquals(null, evaluate(division, field, null));

		division.setDefaultValue("-1");
		assertEquals("-1", evaluate(division, field, null));

		// For a missing input, the mapMissingTo value takes effect instead of the default value
		division.setMapMissingTo("missing");
		assertEquals("missing", evaluate(division, field, null));

		division.setInvalidValueTreatment(InvalidValueTreatmentMethod.RETURN_INVALID);

		try {
			evaluate(division, field, 1);

			fail();
		} catch(InvalidResultException expected){
			// Expected: the division result is invalid
		}

		division.setInvalidValueTreatment(InvalidValueTreatmentMethod.AS_IS);

		try {
			evaluate(division, field, 1);

			fail();
		} catch(InvalidResultException expected){
			// Expected: the division result is invalid
		}

		// Treating the invalid result as missing yields the default value
		division.setInvalidValueTreatment(InvalidValueTreatmentMethod.AS_MISSING);

		assertEquals("-1", evaluate(division, field, 1));
	}

	/**
	 * Verifies the {@code if} function as its argument list grows:
	 * a condition alone and a four-argument form are both rejected with a {@link FunctionException},
	 * whereas the two-argument (then) and three-argument (then/else) forms evaluate normally.
	 */
	@Test
	public void evaluateApplyCondition(){
		FieldName field = FieldName.create("x");

		Apply isPresent = new Apply(PMMLFunctions.ISNOTMISSING);
		isPresent.addExpressions(new FieldRef(field));

		// One argument: condition only
		Apply conditional = new Apply(PMMLFunctions.IF);
		conditional.addExpressions(isPresent);

		try {
			evaluate(conditional, field, null);

			fail();
		} catch(FunctionException expected){
			// Expected: too few arguments
		}

		// Two arguments: condition and then-branch
		Expression thenBranch = new Apply(PMMLFunctions.ABS)
			.addExpressions(new FieldRef(field));

		conditional.addExpressions(thenBranch);

		assertEquals(1, evaluate(conditional, field, 1));
		assertEquals(1, evaluate(conditional, field, -1));

		// Without an else-branch, a false condition produces a missing result
		assertEquals(null, evaluate(conditional, field, null));

		// Three arguments: condition, then-branch and else-branch
		Expression elseBranch = new Constant("-1")
			.setDataType(DataType.DOUBLE);

		conditional.addExpressions(elseBranch);

		assertEquals(-1d, evaluate(conditional, field, null));

		// Four arguments
		conditional.addExpressions(new FieldRef(field));

		try {
			evaluate(conditional, field, null);

			fail();
		} catch(FunctionException expected){
			// Expected: too many arguments
		}
	}

	/**
	 * Verifies that an {@link Apply} can name a function by its Java class name ({@link EchoFunction}):
	 * evaluating it without declaring the referenced field fails with an {@link EvaluationException}
	 * whose context is the offending {@link FieldRef}, and evaluating it with the field declared returns the input.
	 */
	@Test
	public void evaluateApplyJavaFunction(){
		FieldName field = FieldName.create("x");

		FieldRef reference = new FieldRef(field);

		Apply echo = new Apply(EchoFunction.class.getName());
		echo.addExpressions(reference);

		// No arguments declared at all
		try {
			evaluate(echo);

			fail();
		} catch(EvaluationException failure){
			assertEquals(reference, failure.getContext());
		}

		String greeting = "Hello World!";

		assertEquals("Hello World!", evaluate(echo, field, greeting));
	}

	/**
	 * Verifies the count, sum and average functions of an {@link Aggregate} over the integers 1, 2 and 3.
	 */
	@Test
	public void evaluateAggregateArithmetic(){
		FieldName field = FieldName.create("x");

		List<Integer> numbers = Arrays.asList(1, 2, 3);

		Aggregate aggregate = new Aggregate(field, Aggregate.Function.COUNT);

		Object count = evaluate(aggregate, field, numbers);
		assertEquals(3, count);

		aggregate.setFunction(Aggregate.Function.SUM);

		Object sum = evaluate(aggregate, field, numbers);
		assertEquals(6, sum);

		// The average is asserted as a double, unlike the integer count and sum
		aggregate.setFunction(Aggregate.Function.AVERAGE);

		Object average = evaluate(aggregate, field, numbers);
		assertEquals(2d, average);
	}

	/**
	 * Verifies the count, min and max functions of an {@link Aggregate} over three ordinal dates.
	 * When the type information is constructed with the reversed value list, min and max swap their results
	 * (presumably because that list declares the ordinal ordering - confirm against {@link SimpleTypeInfo}).
	 */
	@Test
	public void evaluateAggregate(){
		FieldName field = FieldName.create("x");

		List<?> values = Arrays.asList(TypeUtil.parse(DataType.DATE, "2013-01-01"), TypeUtil.parse(DataType.DATE, "2013-02-01"), TypeUtil.parse(DataType.DATE, "2013-03-01"));

		Object first = values.get(0);
		Object last = values.get(2);

		// Type information constructed without a value list
		TypeInfo plainTypeInfo = new SimpleTypeInfo(DataType.DATE, OpType.ORDINAL);

		Map<FieldName, FieldValue> plainArguments = Collections.singletonMap(field, FieldValue.create(plainTypeInfo, values));

		Aggregate aggregate = new Aggregate(field, Aggregate.Function.COUNT);

		assertEquals(3, evaluate(aggregate, plainArguments));

		aggregate.setFunction(Aggregate.Function.MIN);
		assertEquals(first, evaluate(aggregate, plainArguments));

		aggregate.setFunction(Aggregate.Function.MAX);
		assertEquals(last, evaluate(aggregate, plainArguments));

		// Type information constructed with the values listed in reverse
		TypeInfo reversedTypeInfo = new SimpleTypeInfo(DataType.DATE, OpType.ORDINAL, Lists.reverse(values));

		Map<FieldName, FieldValue> reversedArguments = Collections.singletonMap(field, FieldValue.create(reversedTypeInfo, values));

		aggregate.setFunction(Aggregate.Function.MIN);
		assertEquals(last, evaluate(aggregate, reversedArguments));

		aggregate.setFunction(Aggregate.Function.MAX);
		assertEquals(first, evaluate(aggregate, reversedArguments));
	}

	/**
	 * Builds an inline table whose columns are named after the input, output and regex fields
	 * of the specified text index normalization, in that order.
	 *
	 * @param rows The table content; each row lists its cell values in column order.
	 * @param textIndexNormalization The element that supplies the three column names.
	 * @return The formatted inline table.
	 */
	static
	InlineTable createInlineTable(List<List<String>> rows, TextIndexNormalization textIndexNormalization){
		String inColumn = textIndexNormalization.getInField();
		String outColumn = textIndexNormalization.getOutField();
		String regexColumn = textIndexNormalization.getRegexField();

		List<String> columns = Arrays.asList(inColumn, outColumn, regexColumn);

		return createInlineTable(rows, columns);
	}

	/**
	 * Builds an inline table from row-major cell values.
	 * Rows are keyed by their one-based position; {@code null} cells are left unset.
	 *
	 * @param rows The table content; each row must provide one element per column.
	 * @param columns The column names.
	 * @return The formatted inline table.
	 */
	static
	InlineTable createInlineTable(List<List<String>> rows, List<String> columns){
		List<Integer> keys = new ArrayList<>();

		// Row keys are the one-based row numbers
		for(int number = 1; number <= rows.size(); number++){
			keys.add(number);
		}

		Table<Integer, String, String> cells = ArrayTable.create(keys, columns);

		int rowIndex = 0;

		for(List<String> row : rows){
			Integer key = keys.get(rowIndex);

			for(int columnIndex = 0; columnIndex < columns.size(); columnIndex++){
				String column = columns.get(columnIndex);
				String cell = row.get(columnIndex);

				// Only non-null cells are stored
				if(cell != null){
					cells.put(key, column, cell);
				}
			}

			rowIndex++;
		}

		return InlineTableUtil.format(cells);
	}

	/**
	 * Evaluates the expression against arguments given as a flat varargs list.
	 * The list is converted to an argument map by {@code ModelEvaluatorTest.createArguments};
	 * the tests in this class pass alternating field names and values.
	 *
	 * @param expression The expression to evaluate.
	 * @param objects The flat argument list; may be empty.
	 * @return The evaluation result, unwrapped from its {@link FieldValue}.
	 */
	static
	private Object evaluate(Expression expression, Object... objects){
		Map<FieldName, ?> argumentMap = ModelEvaluatorTest.createArguments(objects);

		Object value = evaluate(expression, argumentMap);

		return value;
	}

	/**
	 * Evaluates the expression in a fresh {@link VirtualEvaluationContext}
	 * in which all the specified arguments have been declared.
	 *
	 * @param expression The expression to evaluate.
	 * @param arguments The field values to declare, keyed by field name.
	 * @return The evaluation result, unwrapped from its {@link FieldValue} by {@code FieldValueUtil.getValue}.
	 */
	static
	private Object evaluate(Expression expression, Map<FieldName, ?> arguments){
		EvaluationContext evaluationContext = new VirtualEvaluationContext();

		// Make every argument visible to the expression before evaluating it
		evaluationContext.declareAll(arguments);

		FieldValue fieldValue = ExpressionUtil.evaluate(expression, evaluationContext);

		Object value = FieldValueUtil.getValue(fieldValue);

		return value;
	}
}