package org.gd.spark.opendl.example;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.gd.spark.opendl.downpourSGD.SampleVector;

/**
 * Utility for loading MNIST-style CSV sample data and preparing it for
 * downpour-SGD training: read from disk, parallelize to a Spark RDD, and
 * split into train/test subsets.
 *
 * <p>Not designed for instantiation — all members are static.
 */
public class DataInput {
	/** Shared RNG for the random train/test split; seeded once at class load. */
	private static Random rand = new Random(System.currentTimeMillis());

	/** Utility class — prevent instantiation. */
	private DataInput() {
	}

	/**
	 * Read sample data from an mnist_784-format text file.
	 *
	 * <p>Each line is comma-separated: the first field is the integer class
	 * label (used as a one-hot index into the Y vector), the remaining
	 * {@code x_feature} fields are the pixel/feature values.
	 *
	 * @param path path of the CSV file to read
	 * @param x_feature number of input features per sample (e.g. 784)
	 * @param y_feature number of output classes (e.g. 10)
	 * @return list of parsed samples, one per input line
	 * @throws Exception on I/O failure or malformed numeric fields
	 */
	public static List<SampleVector> readMnist(String path, int x_feature, int y_feature) throws Exception {
		List<SampleVector> ret = new ArrayList<SampleVector>();
		// try-with-resources guarantees the reader is closed even when a
		// read or parse error is thrown mid-file (the original leaked here).
		// NOTE(review): FileReader uses the platform default charset — fine
		// for ASCII digit data, but confirm if the input may vary.
		try (BufferedReader br = new BufferedReader(new FileReader(path))) {
			String str = null;
			while (null != (str = br.readLine())) {
				String[] splits = str.split(",");
				SampleVector xy = new SampleVector(x_feature, y_feature);
				// First column is the class label; one-hot encode it into Y.
				// parseInt/parseDouble avoid the boxing of Integer/Double.valueOf.
				xy.getY()[Integer.parseInt(splits[0])] = 1;
				for (int i = 1; i < splits.length; i++) {
					xy.getX()[i - 1] = Double.parseDouble(splits[i]);
				}
				ret.add(xy);
			}
		}
		return ret;
	}

	/**
	 * Parallelize an in-memory sample list into a Spark RDD.
	 *
	 * @param context active Spark context used to create the RDD
	 * @param list samples to distribute
	 * @return RDD backed by {@code list} with Spark's default partitioning
	 * @throws Exception propagated from Spark if parallelization fails
	 */
	public static JavaRDD<SampleVector> toRDD(JavaSparkContext context, List<SampleVector> list) throws Exception {
		return context.parallelize(list);
	}

	/**
	 * Randomly split the full sample list into train and test parts.
	 *
	 * <p>Each sample independently lands in {@code trainList} with probability
	 * {@code trainRatio}, otherwise in {@code testList}; results are appended
	 * to the caller-supplied lists. The split is random per call, so exact
	 * sizes vary around {@code trainRatio * totalList.size()}.
	 *
	 * @param totalList all samples read from file
	 * @param trainList output list receiving the training portion
	 * @param testList output list receiving the test portion
	 * @param trainRatio probability in [0, 1] that a sample goes to training
	 */
	public static void splitList(List<SampleVector> totalList, List<SampleVector> trainList, List<SampleVector> testList, double trainRatio) {
		for (SampleVector sample : totalList) {
			if (rand.nextDouble() <= trainRatio) {
				trainList.add(sample);
			}
			else {
				testList.add(sample);
			}
		}
	}
}