package collabfilter;

import static com.datastax.spark.connector.japi.CassandraJavaUtil.javaFunctions;

import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;

import scala.Tuple2;

import com.datastax.spark.connector.cql.CassandraConnector;
import com.datastax.spark.connector.japi.CassandraRow;
import com.datastax.spark.connector.japi.rdd.CassandraJavaRDD;

/**
 * Collaborative-filtering implementation written for Java 8. The main
 * difference from a pre-Java-8 implementation is the use of lambdas instead of
 * explicit classes for Spark functions.
 * <p>
 * Ratings are read from Cassandra rows using the column names declared in
 * {@code RatingDO}; the rating column is always read as a {@code double}.
 */
public class CollabFilterCassandra8 implements ICollabFilterCassandra {
	/** Number of iterations passed to {@code ALS.train}. */
	private static final int ITER = 20;
	/** Rank (number of latent features) passed to {@code ALS.train}. */
	private static final int RANK = 6;
	/** Lambda (regularization) value passed to {@code ALS.train}. */
	private static final double LAMBDA = 0.01;

	/**
	 * Trains an ALS model on every row of the ratings table.
	 *
	 * @param sparkCtx           Spark context used to open the Cassandra table
	 * @param cassandraConnector not used by this implementation
	 * @return the trained matrix-factorization model
	 */
	public MatrixFactorizationModel train(JavaSparkContext sparkCtx, CassandraConnector cassandraConnector) {
		CassandraJavaRDD<CassandraRow> trainingRdd = javaFunctions(sparkCtx).cassandraTable(RatingDO.EMPLOYERRATINGS_KEYSPACE, RatingDO.RATINGS_TABLE);
		JavaRDD<Rating> trainingJavaRdd = trainingRdd.map(trainingRow -> toRating(trainingRow));
		return ALS.train(JavaRDD.toRDD(trainingJavaRdd), RANK, ITER, LAMBDA);
	}

	/**
	 * Predicts a rating for each (user, product) pair found in the validation
	 * rows.
	 *
	 * @param model              trained model used for the prediction
	 * @param validationsCassRdd rows supplying the user and product columns
	 * @return the predicted ratings
	 */
	public JavaRDD<Rating> predict(MatrixFactorizationModel model, CassandraJavaRDD<CassandraRow> validationsCassRdd) {
		// MatrixFactorizationModel.predict takes a Scala RDD of (Object, Object) pairs.
		JavaRDD<Tuple2<Object, Object>> userProductPairs = validationsCassRdd.map(validationRow -> new Tuple2<Object, Object>(validationRow.getInt(RatingDO.USER_COL), validationRow.getInt(RatingDO.PRODUCT_COL)));
		RDD<Tuple2<Object, Object>> validationsRdd = JavaRDD.toRDD(userProductPairs);
		return model.predict(validationsRdd).toJavaRDD();
	}

	/**
	 * Computes the root-mean-square error between the predictions and the
	 * ratings stored in the validation rows. Only (user, product) pairs present
	 * on both sides contribute, because the two data sets are inner-joined.
	 *
	 * @param predictionJavaRdd  predicted ratings
	 * @param validationsCassRdd rows holding the actual ratings
	 * @return the RMSE over the joined pairs
	 */
	public double validate(JavaRDD<Rating> predictionJavaRdd, CassandraJavaRDD<CassandraRow> validationsCassRdd) {
		JavaPairRDD<Tuple2<Integer, Integer>, Double> predictionsJavaPairs = JavaPairRDD.fromJavaRDD(predictionJavaRdd.map(pred -> keyByUserAndProduct(pred)));
		// The rating is read with getDouble (via toRating), exactly as in train(),
		// so the error is measured against the same values the model was trained on.
		JavaPairRDD<Tuple2<Integer, Integer>, Double> validationJavaPairs = JavaPairRDD.fromJavaRDD(validationsCassRdd.map(validationRow -> keyByUserAndProduct(toRating(validationRow))));
		JavaRDD<Tuple2<Double, Double>> validationAndPredictions = validationJavaPairs.join(predictionsJavaPairs).values();

		JavaDoubleRDD squaredErrors = validationAndPredictions.mapToDouble(pair -> {
			double err = pair._1() - pair._2();
			return err * err;
		});
		double meanSquaredError = squaredErrors.mean();
		return Math.sqrt(meanSquaredError);
	}

	/**
	 * Builds a tab-separated, human-readable report: a header line, one line
	 * per prediction and a final line with the rounded RMSE.
	 *
	 * @param predJavaRdd        predicted ratings
	 * @param validationsCassRdd rows holding the actual ratings
	 * @param rmse               error value to print on the last line
	 * @return the report text
	 */
	public String resultsReport(JavaRDD<Rating> predJavaRdd, CassandraJavaRDD<CassandraRow> validationsCassRdd, double rmse) {
		return "User\tProduct\tPredicted\tActual\tError?\n" + predictionString(predJavaRdd, validationsCassRdd) + "\n" + "RMSE = " + Util.round(rmse, 2);
	}

	/**
	 * Formats one line per prediction, ordered by user and then by product.
	 * Both RDDs are collected to the driver.
	 *
	 * @throws NullPointerException if a prediction has no validation row with
	 *                              the same (user, product) pair
	 */
	private String predictionString(JavaRDD<Rating> predJavaRdd, CassandraJavaRDD<CassandraRow> validationsCassRdd) {
		java.util.function.Function<CassandraRow, Tuple2<Integer, Integer>> keyMapper = validationRow -> new Tuple2<Integer, Integer>(validationRow.getInt(RatingDO.USER_COL), validationRow.getInt(RatingDO.PRODUCT_COL));
		java.util.function.Function<CassandraRow, Double> valueMapper = validationRow -> validationRow.getDouble(RatingDO.RATING_COL);
		java.util.Map<Tuple2<Integer, Integer>, Double> validationMap = validationsCassRdd.collect().stream().collect(Collectors.toMap(keyMapper, valueMapper));

		java.util.function.Function<Rating, String> stringMapper = prediction -> {
			// Unboxing the looked-up value fails with a NullPointerException when the pair is absent.
			double validationRating = validationMap.get(new Tuple2<Integer, Integer>(prediction.user(), prediction.product()));
			String errWarningString = Math.abs(validationRating - prediction.rating()) >= 1 ? "ERR" : "OK";
			return prediction.user() + "\t" + prediction.product() + "\t" + Util.round(prediction.rating()) + "\t\t" + Util.round(validationRating) + "\t" + errWarningString;
		};
		// Integer.compare instead of subtraction: a difference of two ints can overflow
		// and yield the wrong sign.
		Stream<Rating> sortedPredictions = predJavaRdd.collect().stream().sorted((o1, o2) -> o1.user() == o2.user() ? Integer.compare(o1.product(), o2.product()) : Integer.compare(o1.user(), o2.user()));
		return sortedPredictions.map(stringMapper).collect(Collectors.joining("\n"));
	}

	/** Converts a ratings row into an MLlib {@code Rating}, reading the rating column as a double. */
	private static Rating toRating(CassandraRow row) {
		return new Rating(row.getInt(RatingDO.USER_COL), row.getInt(RatingDO.PRODUCT_COL), row.getDouble(RatingDO.RATING_COL));
	}

	/** Maps a rating to a ((user, product), rating) pair so that ratings can be joined on user and product. */
	private static Tuple2<Tuple2<Integer, Integer>, Double> keyByUserAndProduct(Rating rating) {
		return new Tuple2<Tuple2<Integer, Integer>, Double>(new Tuple2<Integer, Integer>(rating.user(), rating.product()), rating.rating());
	}

}