package com.clust4j.data;

import static org.junit.Assert.*;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.junit.Test;

import com.clust4j.TestSuite;
import com.clust4j.algo.KMeans;
import com.clust4j.algo.KMeansParameters;
import com.clust4j.metrics.scoring.SupervisedMetric;

public class TrainTestSplitTests {

	@Test
	public void testIris() {
		TrainTestSplit split = new TrainTestSplit(TestSuite.IRIS_DATASET, 0.7);
		
		DataSet train = split.getTrain();
		DataSet test = split.getTest();
		
		assertTrue(train.numRows() + test.numRows() == TestSuite.IRIS_DATASET.numRows());
	}

	@Test
	public void testLowerBoundWithLabels() {
		DataSet set = new DataSet(
			new Array2DRowRealMatrix(new double[][]{
				new double[]{0,0,0},
				new double[]{1,1,1}
			}, false),
			new int[]{0,0}
		);
		
		TrainTestSplit split = new TrainTestSplit(set, 0.8);
		assertTrue(split.getTrain().numRows() == 1);
		assertTrue(split.getTest().numRows() == 1);
		
		split = new TrainTestSplit(set, 0.1);
		assertTrue(split.getTrain().numRows() == 1);
		assertTrue(split.getTest().numRows() == 1);
	}
	
	@Test
	public void testLowerBoundWithNoLabels() {
		int[] labels = null;
		DataSet set = new DataSet(
			new Array2DRowRealMatrix(new double[][]{
				new double[]{0,0,0},
				new double[]{1,1,1}
			}, false),
			
			labels // null
		);
		
		TrainTestSplit split = new TrainTestSplit(set, 0.8);
		assertTrue(split.getTrain().numRows() == 1);
		assertTrue(split.getTest().numRows() == 1);
		
		split = new TrainTestSplit(set, 0.1);
		assertTrue(split.getTrain().numRows() == 1);
		assertTrue(split.getTest().numRows() == 1);
	}
	
	@Test
	public void testExceptions() {
		DataSet set = new DataSet(
			new Array2DRowRealMatrix(new double[][]{
				new double[]{0,0,0}
			}, false),
			new int[]{1}
		);
		
		/*
		 * Test one row fails
		 */
		boolean a= false;
		try {
			new TrainTestSplit(set, 0.5);
		} catch(IllegalArgumentException i) {
			a = true;
		} finally {
			assertTrue(a);
		}
		
		
		// re-assign
		set = new DataSet(
			new Array2DRowRealMatrix(new double[][]{
				new double[]{0,0,0},
				new double[]{1,1,1}
			}, false),
			new int[]{1,1}
		);
		
		/*
		 * test fails on 1.0+
		 */
		a= false;
		try {
			new TrainTestSplit(set, 1.0);
		} catch(IllegalArgumentException i) {
			a = true;
		} finally {
			assertTrue(a);
		}
		
		a= false;
		try {
			new TrainTestSplit(set, 1.1);
		} catch(IllegalArgumentException i) {
			a = true;
		} finally {
			assertTrue(a);
		}
		
		/*
		 * test fails on 0.0-
		 */
		a= false;
		try {
			new TrainTestSplit(set, 0.0);
		} catch(IllegalArgumentException i) {
			a = true;
		} finally {
			assertTrue(a);
		}
		
		a= false;
		try {
			new TrainTestSplit(set,-0.1);
		} catch(IllegalArgumentException i) {
			a = true;
		} finally {
			assertTrue(a);
		}
	}
	
	@Test
	public void testOnModel() {
		TrainTestSplit split = new TrainTestSplit(TestSuite.IRIS_DATASET, 0.75);
		DataSet train = split.getTrain();
		DataSet test  = split.getTest();
		
		KMeans model = new KMeansParameters(3).fitNewModel(train.getData());
		int[] predictions = model.predict(test.getData());
		
		// examine affinity:
		System.out.println("Affinity: " + SupervisedMetric.INDEX_AFFINITY.evaluate(test.getLabels(), predictions));
	}
}