package edu.sdsc.mmtf.spark.mappers;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.rcsb.mmtf.api.StructureDataInterface;

import scala.Tuple2;

/**
 * Finds the groups (residues) that are within a cutoff distance of each
 * occurrence of a named group in a structure.
 *
 * <p>For every group in the first model whose name equals the configured group
 * name, one {@link Row} is emitted per other group that has at least one atom
 * within the cutoff distance of any atom of that query group. Each row has the
 * columns: structure id, query group name, query group index, interacting
 * group name, interacting group index, minimum inter-atomic distance (float).
 * Group indices are positions in the sequential list of groups of the first
 * model.
 *
 * @author Peter Rose
 *
 */
public class StructureToInteractingResidues implements  FlatMapFunction<Tuple2<String,StructureDataInterface>, Row> {
	private static final long serialVersionUID = -3348372120358649240L;
	private final String groupName;
	private final double cutoffDistance;

	/**
	 * @param groupName name of the query group; compared with {@code equals}
	 *        against the group names reported by the structure
	 * @param cutoffDistance maximum atom-atom distance, in the same units as
	 *        the structure's coordinates, for two groups to count as interacting
	 */
	public StructureToInteractingResidues(String groupName, double cutoffDistance) {
		this.groupName = groupName;
		this.cutoffDistance = cutoffDistance;
	}

	/**
	 * Generates the interaction rows for one structure.
	 *
	 * @param t pair of structure id and structure data
	 * @return iterator over one row per (query group, interacting group) pair;
	 *         empty if the structure contains no group with the configured name
	 */
	@Override
	public Iterator<Row> call(Tuple2<String, StructureDataInterface> t) throws Exception {
		String structureId = t._1;
		StructureDataInterface structure = t._2;

		List<Integer> groupIndices = new ArrayList<>();
		List<String> groupNames = new ArrayList<>();
		getGroupIndices(structure, groupIndices, groupNames);

		List<Row> neighbors = new ArrayList<>();
		for (int i = 0; i < groupNames.size(); i++) {
			if (groupNames.get(i).equals(groupName)) {
				// cheap bounding-box pre-filter first, exact distance check second
				float[] boundingBox = calcBoundingBox(structure, groupIndices, i, cutoffDistance);
				List<Integer> matches = findNeighbors(structure, i, boundingBox, groupIndices);
				neighbors.addAll(getDistanceProfile(structureId, matches, i, groupIndices, groupNames, structure));
			}
		}

		return neighbors.iterator();
	}

	/**
	 * Computes, for each candidate group, the minimum atom-atom distance to the
	 * query group and creates a row for every candidate whose minimum distance
	 * does not exceed the cutoff.
	 *
	 * @param structureId id stored in the first column of each row
	 * @param matches indices of candidate groups (the query group itself is skipped)
	 * @param index index of the query group
	 * @param groupIndices atom start offsets per group, with one trailing end offset
	 * @param groupNames group names, parallel to the groups in {@code groupIndices}
	 * @param structure source of the atom coordinates
	 * @return rows for the candidates that are within the cutoff distance
	 */
	private List<Row> getDistanceProfile(String structureId, List<Integer> matches, int index, List<Integer> groupIndices, List<String> groupNames, StructureDataInterface structure) {
		// compare squared distances to avoid a sqrt per atom pair
		double cutoffDistanceSq = cutoffDistance * cutoffDistance;

		float[] x = structure.getxCoords();
		float[] y = structure.getyCoords();
		float[] z = structure.getzCoords();

		int first = groupIndices.get(index);
		int last = groupIndices.get(index+1);

		List<Row> rows = new ArrayList<>();
		for (int i: matches) {
			if (i == index) {
				continue;
			}
			int start = groupIndices.get(i);
			int end = groupIndices.get(i+1);

			double minDSq = Double.MAX_VALUE;
			boolean withinCutoff = false;
			for (int j = start; j < end; j++) {
				for (int k = first; k < last; k++) {
					double dx = (x[j] - x[k]);
					double dy = (y[j] - y[k]);
					double dz = (z[j] - z[k]);
					double dSq = dx*dx + dy*dy + dz*dz;
					if (dSq <= cutoffDistanceSq && dSq < minDSq) {
						minDSq = dSq;
						withinCutoff = true;
					}
				}
			}
			if (withinCutoff) {
				// TODO add unique group (and atom?) for each group?
				Row row = RowFactory.create(structureId, groupNames.get(index), index, groupNames.get(i), i, (float)Math.sqrt(minDSq));
				rows.add(row);
			}
		}
		return rows;
	}

	/**
	 * Returns the indices of all groups that have at least one atom inside the
	 * given bounding box. The query group itself is not excluded here.
	 *
	 * @param structure source of the atom coordinates
	 * @param index index of the query group (currently unused)
	 * @param boundingBox {xMin, xMax, yMin, yMax, zMin, zMax}, bounds inclusive
	 * @param groupIndices atom start offsets per group, with one trailing end offset
	 * @return group indices in ascending order
	 */
	private List<Integer> findNeighbors(StructureDataInterface structure, int index, float[] boundingBox, List<Integer>groupIndices) {
		float[] x = structure.getxCoords();
		float[] y = structure.getyCoords();
		float[] z = structure.getzCoords();

		List<Integer> matches = new ArrayList<>();

		int numGroups = groupIndices.size() - 1;
		for (int i = 0; i < numGroups; i++) {
			int start = groupIndices.get(i);
			int end = groupIndices.get(i+1);
			for (int j = start; j < end; j++) {
				if (x[j] >= boundingBox[0] && x[j] <= boundingBox[1] &&
						y[j] >= boundingBox[2] && y[j] <= boundingBox[3] &&
						z[j] >= boundingBox[4] && z[j] <= boundingBox[5]) {
					matches.add(i);
					// one atom inside the box is enough to keep the group
					break;
				}
			}
		}
		return matches;
	}

	/**
	 * Calculates the axis-aligned bounding box of a group's atoms, expanded by
	 * the cutoff distance in every direction.
	 *
	 * @param structure source of the atom coordinates
	 * @param groupIndices atom start offsets per group, with one trailing end offset
	 * @param index index of the group to enclose
	 * @param cutoffDistance margin added on each side of the box
	 * @return {xMin, xMax, yMin, yMax, zMin, zMax}
	 */
	private float[] calcBoundingBox(StructureDataInterface structure, List<Integer> groupIndices, int index, double cutoffDistance) {
		float[] x = structure.getxCoords();
		float[] y = structure.getyCoords();
		float[] z = structure.getzCoords();

		// Seed with infinities. Float.MIN_VALUE must not be used for the maxima:
		// it is the smallest positive float, so a group with only negative
		// coordinates would end up with a maximum near zero.
		float xMin = Float.POSITIVE_INFINITY;
		float xMax = Float.NEGATIVE_INFINITY;
		float yMin = Float.POSITIVE_INFINITY;
		float yMax = Float.NEGATIVE_INFINITY;
		float zMin = Float.POSITIVE_INFINITY;
		float zMax = Float.NEGATIVE_INFINITY;

		int first = groupIndices.get(index);
		int last = groupIndices.get(index+1);

		for (int i = first; i < last; i++) {
			xMin = Math.min(xMin, x[i]);
			xMax = Math.max(xMax, x[i]);
			yMin = Math.min(yMin, y[i]);
			yMax = Math.max(yMax, y[i]);
			zMin = Math.min(zMin, z[i]);
			zMax = Math.max(zMax, z[i]);
		}

		float[] boundingBox = new float[6];
		boundingBox[0] = (float) (xMin - cutoffDistance);
		boundingBox[1] = (float) (xMax + cutoffDistance);
		boundingBox[2] = (float) (yMin - cutoffDistance);
		boundingBox[3] = (float) (yMax + cutoffDistance);
		boundingBox[4] = (float) (zMin - cutoffDistance);
		boundingBox[5] = (float) (zMax + cutoffDistance);

		return boundingBox;

	}

	/**
	 * Fills the atom offsets and names for all groups of the first model.
	 *
	 * <p>After the call, {@code groupIndices} holds one more entry than
	 * {@code groupNames}: the atoms of group {@code g} occupy the coordinate
	 * indices {@code [groupIndices.get(g), groupIndices.get(g+1))}.
	 *
	 * @param structure structure to read the group layout from
	 * @param groupIndices output list receiving the atom offsets
	 * @param groupNames output list receiving the group names
	 */
	private void getGroupIndices(StructureDataInterface structure, List<Integer> groupIndices,
			List<String> groupNames) {
		int atomCounter= 0;
		int groupCounter= 0;

		// add start index for first group
		groupIndices.add(0);

		int[] chainsPerModel = structure.getChainsPerModel();
		if (chainsPerModel.length == 0) {
			// no models: leave the group lists empty instead of failing
			return;
		}
		// only the chains of the first model are considered
		int numChains = chainsPerModel[0];

		int[] groupsPerChain = structure.getGroupsPerChain();
		int[] groupTypeIndices = structure.getGroupTypeIndices();

		for (int i = 0; i < numChains; i++) {
			for (int j = 0; j < groupsPerChain[i]; j++) {
				int groupIndex = groupTypeIndices[groupCounter];
				groupNames.add(structure.getGroupName(groupIndex));
				atomCounter+= structure.getNumAtomsInGroup(groupIndex);
				groupIndices.add(atomCounter);
				groupCounter++;
			}
		}
	}
}