package edu.sdsc.mmtf.spark.alignments;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import javax.vecmath.Point3d;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.rcsb.mmtf.api.StructureDataInterface;

import edu.sdsc.mmtf.spark.mappers.StructuralAlignmentMapper;
import edu.sdsc.mmtf.spark.utils.ColumnarStructureX;
import scala.Tuple2;

/**
 * This class performs parallel structure alignments. It performs
 * all vs. all and query set vs. all alignments.
 * 
 * @author Peter Rose
 * @since 0.2.0
 *
 */
public class StructureAligner implements Serializable {
	private static final long serialVersionUID = -7649106216436396239L;
	private static int NUM_TASKS = 3; // number of tasks per partition, Spark doc. suggests to use around 3.

	/**
	 * Calculates all vs. all structural alignments of protein chains using the 
	 * specified alignment algorithm. The input structures must contain single 
	 * protein chains.
	 * 
	 * @param targets structures containing single protein chains
	 * @param alignmentAlgorithm name of the algorithm
	 * @return dataset with alignment metrics
	 */
	public static Dataset<Row> getAllVsAllAlignments(JavaPairRDD<String, StructureDataInterface> targets,
			String alignmentAlgorithm) {

		SparkSession session = SparkSession.builder().getOrCreate();
		JavaSparkContext sc = new JavaSparkContext(session.sparkContext());

		// create a list of chainName/ C Alpha coordinates
		List<Tuple2<String, Point3d[]>> chains  = targets.mapValues(
				s -> new ColumnarStructureX(s,true).getcAlphaCoordinates()).collect();

		// create an RDD of all pair indices (0,1), (0,2), ..., (1,2), (1,3), ...
		JavaRDD<Tuple2<Integer, Integer>> pairs = getPairs(sc, chains.size());
		
		// calculate structural alignments for all pairs.
		// broadcast (copy) chains to all worker nodes for efficient processing.
		// for each pair there can be zero or more solutions, therefore we flatmap the pairs.
		JavaRDD<Row> rows = pairs.flatMap(new StructuralAlignmentMapper(sc.broadcast(chains), alignmentAlgorithm));

		// convert rows to a dataset
		return session.createDataFrame(rows, getSchema());
	}

	/**
	 * Calculates structural alignments between a query and a target set of protein chains 
	 * using the specified alignment algorithm. An input structures must contain single 
	 * protein chains.
	 * 
	 * @param targets structures containing single protein chains
	 * @param alignmentAlgorithm name of the algorithm
	 * @return dataset with alignment metrics
	 */
	public static Dataset<Row> getQueryVsAllAlignments(
			JavaPairRDD<String, StructureDataInterface> queries, 
			JavaPairRDD<String, StructureDataInterface> targets,
			String alignmentAlgorithm) {

		SparkSession session = SparkSession.builder().getOrCreate();
		@SuppressWarnings("resource") // spark context should not be closed here
		JavaSparkContext sc = new JavaSparkContext(session.sparkContext());

		List<Tuple2<String, Point3d[]>> chains = new ArrayList<>();

		// create a list of chainName/ C Alpha coordinates for query chains
		chains.addAll(queries.mapValues(
				s -> new ColumnarStructureX(s,true).getcAlphaCoordinates()).collect());

		int querySize = chains.size();

		// create a list of chainName/ C Alpha coordinates for target chains
		chains.addAll(targets.mapValues(
				s -> new ColumnarStructureX(s,true).getcAlphaCoordinates()).collect());

		// create an RDD with indices for all query - target pairs (q, t)
		List<Tuple2<Integer, Integer>> pairList = new ArrayList<>(chains.size());
		for (int q = 0; q < querySize; q++) {
			for (int t = querySize; t < chains.size(); t++) {
				pairList.add(new Tuple2<Integer, Integer>(q, t));
			}
		}
		JavaRDD<Tuple2<Integer, Integer>> pairs = sc.parallelize(pairList, NUM_TASKS*sc.defaultParallelism());
		
		// calculate structural alignments for all pairs.
		// the chains are broadcast (copied) to all worker nodes for efficient processing
		JavaRDD<Row> rows = pairs.flatMap(new StructuralAlignmentMapper(sc.broadcast(chains), alignmentAlgorithm));

		// convert rows to a dataset
		return session.createDataFrame(rows, getSchema());
	}

	/**
	 * Creates the schema for the alignment dataset.
	 * @return Schema for the alignment dataset
	 */
	private static StructType getSchema() {
		boolean nullable = false;
		StructField[] sf = {
				DataTypes.createStructField("id", DataTypes.StringType, nullable),
				DataTypes.createStructField("length", DataTypes.IntegerType, nullable),
				DataTypes.createStructField("coverage1", DataTypes.IntegerType, nullable),
				DataTypes.createStructField("coverage2", DataTypes.IntegerType, nullable),
				DataTypes.createStructField("rmsd", DataTypes.FloatType, nullable),
				DataTypes.createStructField("tm", DataTypes.FloatType, nullable)
		};

		return DataTypes.createStructType(sf);
	}

	/**
	 * Creates an RDD of all n*(n-1)/2 unique pairs for pairwise structural alignments.
	 * @param sc spark context
	 * @param n number of protein chains
	 * @return
	 */
	private static JavaRDD<Tuple2<Integer, Integer>> getPairs(JavaSparkContext sc, int n) {
		// create a list of integers from 0 - n-1
		List<Integer> range = IntStream.range(0, n).boxed().collect(Collectors.toList());

		JavaRDD<Integer> pRange = sc.parallelize(range, NUM_TASKS*sc.defaultParallelism());

		// flatmap this list of integers into all unique pairs 
		// (0,1),(0,2),...(0,n-1),  (1,2)(1,3),..,(1,n-1),  (2,3),(2,4),...
		return pRange.flatMap(new FlatMapFunction<Integer, Tuple2<Integer,Integer>>() {
			private static final long serialVersionUID = -432662341173300339L;

			@Override
			public Iterator<Tuple2<Integer, Integer>> call(Integer t) throws Exception {
				List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();

				for (int i = 0; i < t; i++) {
					pairs.add(new Tuple2<Integer, Integer>(i, t));
				}
				return pairs.iterator();
			}
			// The partitions generated here are not well balanced, which would lead to an
			// unbalanced workload. Here we repartition the pairs for efficient processing.
		}).repartition(NUM_TASKS*sc.defaultParallelism()); 
	}
}