/*
 * Modifications Copyright 2019 Graz University of Technology
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
 
package org.tugraz.sysds.test.component.convert;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.tugraz.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.tugraz.sysds.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.tugraz.sysds.test.AutomatedTestBase;

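/**
 * Component tests for {@link RDDConverterUtilsExt#stringDataFrameToVectorDataFrame}.
 */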
public class RDDConverterUtilsExtTest extends AutomatedTestBase {

	private static SparkConf conf;
	private static JavaSparkContext sc;

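	/**
	 * Creates a single local-mode SparkContext that is shared by all tests
	 * in this class.
	 */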
	@BeforeClass
	public static void setUpClass() {
		if (conf == null)
			conf = SparkExecutionContext.createSystemDSSparkConf().setAppName("RDDConverterUtilsExtTest")
					.setMaster("local");
		if (sc == null)
			sc = new JavaSparkContext(conf);
	}

	@Override
	public void setUp() {
		// no per-test setup required; the shared SparkContext is created
		// once in setUpClass()
	}

	/**
	 * Wraps a plain String in a single-column spark.sql.Row.
	 */
	static class StringToRow implements Function<String, Row> {
		private static final long serialVersionUID = 3945939649355731805L;

		@Override
		public Row call(String str) throws Exception {
			return RowFactory.create(str);
		}
	}

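	/**
	 * Tests that strings in a variety of bracketed and parenthesized formats
	 * are parsed into vector rows.
	 */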
	@Test
	public void testStringDataFrameToVectorDataFrame() {
		List<String> list = new ArrayList<>();
		list.add("((1.2, 4.3, 3.4))");
		list.add("(1.2, 3.4, 2.2)");
		list.add("[[1.2, 34.3, 1.2, 1.25]]");
		list.add("[1.2, 3.4]");
		JavaRDD<String> javaRddString = sc.parallelize(list);
		JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
		List<StructField> fields = new ArrayList<>();
		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
		StructType schema = DataTypes.createStructType(fields);
		Dataset<Row> inDF = sparkSession.createDataFrame(javaRddRow, schema);
		Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);

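		// Row.toString() adds one pair of brackets around the vector's own
		// bracketed representation, hence the double brackets below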
		List<String> expectedResults = new ArrayList<>();
		expectedResults.add("[[1.2,4.3,3.4]]");
		expectedResults.add("[[1.2,3.4,2.2]]");
		expectedResults.add("[[1.2,34.3,1.2,1.25]]");
		expectedResults.add("[[1.2,3.4]]");

		List<Row> outputList = outDF.collectAsList();
		assertEquals("Unexpected number of output rows", expectedResults.size(), outputList.size());
		for (Row row : outputList) {
			assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
		}
	}

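	/**
	 * Tests that null input strings pass through the conversion as null
	 * entries instead of causing a parse failure.
	 */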
	@Test
	public void testStringDataFrameToVectorDataFrameNull() {
		List<String> list = new ArrayList<>();
		list.add("[1.2, 3.4]");
		list.add(null);
		JavaRDD<String> javaRddString = sc.parallelize(list);
		JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
		List<StructField> fields = new ArrayList<>();
		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
		StructType schema = DataTypes.createStructType(fields);
		Dataset<Row> inDF = sparkSession.createDataFrame(javaRddRow, schema);
		Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);

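		// the null input row is expected to survive conversion as a null entry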
		List<String> expectedResults = new ArrayList<>();
		expectedResults.add("[[1.2,3.4]]");
		expectedResults.add("[null]");

		List<Row> outputList = outDF.collectAsList();
		assertEquals("Unexpected number of output rows", expectedResults.size(), outputList.size());
		for (Row row : outputList) {
			assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
		}
	}

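	/**
	 * Tests that non-numeric vector content fails with a SparkException once
	 * the conversion is actually evaluated.
	 */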
	@Test(expected = SparkException.class)
	public void testStringDataFrameToVectorDataFrameNonNumbers() {
		List<String> list = new ArrayList<>();
		list.add("[cheeseburger,fries]");
		JavaRDD<String> javaRddString = sc.parallelize(list);
		JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
		List<StructField> fields = new ArrayList<>();
		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
		StructType schema = DataTypes.createStructType(fields);
		Dataset<Row> inDF = sparkSession.createDataFrame(javaRddRow, schema);
		Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, inDF);
		// transformations are lazy: collecting forces the parse, which is
		// expected to fail with a SparkException
		outDF.collectAsList();
	}

	@After
	@Override
	public void tearDown() {
		super.tearDown();
	}

	@AfterClass
	public static void tearDownClass() {
		// stop the SparkContext so that subsequent tests running in the same
		// JVM can create their own (only one active SparkContext is allowed
		// per JVM)
		sc.stop();
		sc = null;
		conf = null;
	}
}