package io.izenecloud.lr;

import io.izenecloud.larser.feature.OnlineVectorWritable;

import java.io.IOException;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper;

import edu.stanford.nlp.optimization.QNMinimizer;

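/**
 * Map side of one distributed logistic-regression iteration: each mapper
 * gathers the {@link OnlineVectorWritable} rows of its input split, fits a
 * local model with Stanford's {@link QNMinimizer} (L-BFGS), and emits the
 * optimized weight vector as a {@link LaserOnlineModel} under the input key.
 *
 * <p>Hypothetical driver wiring (a sketch only; the input format and job
 * setup are not confirmed by this class):
 *
 * <pre>{@code
 * Job job = Job.getInstance(conf, "lr-iteration");
 * job.setMapperClass(LrIterationMapper.class);
 * job.setMapOutputKeyClass(Text.class);
 * job.setMapOutputValueClass(LaserOnlineModel.class);
 * }</pre>
 */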
public class LrIterationMapper extends
		Mapper<Text, ListWritable, Text, LaserOnlineModel> {
	private static final double DEFAULT_REGULARIZATION_FACTOR = 1e-6;

	private double regularizationFactor; // L2 regularization strength
	private QNMinimizer lbfgs; // Stanford NLP quasi-Newton (L-BFGS) minimizer

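	/**
	 * Reads the L2 regularization factor from the job configuration (falling
	 * back to {@link #DEFAULT_REGULARIZATION_FACTOR}) and creates the L-BFGS
	 * minimizer with its robust option preset.
	 */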
	@Override
	protected void setup(Context context) throws IOException,
			InterruptedException {
		Configuration conf = context.getConfiguration();
		regularizationFactor = conf.getDouble(
				"lr.iteration.regularization.factor",
				DEFAULT_REGULARIZATION_FACTOR);
		lbfgs = new QNMinimizer();
		lbfgs.setRobustOptions();
	}

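	/**
	 * Copies the split's rows into an array, runs the local optimization, and
	 * writes the resulting weight vector keyed by the input key.
	 */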
	@Override
	protected void map(Text key, ListWritable valueWritable, Context context)
			throws IOException, InterruptedException {
		// Materialize this split's training rows into an array.
		List<Writable> value = valueWritable.get();
		OnlineVectorWritable[] inputSplitData = new OnlineVectorWritable[value
				.size()];
		int row = 0;
		for (Writable writable : value) {
			inputSplitData[row++] = (OnlineVectorWritable) writable;
		}

		LrIterationMapContext mapContext = new LrIterationMapContext(
				inputSplitData);
		mapContext = localMapperOptimization(mapContext);

		// Emit the locally optimized weight vector under the original key.
		context.write(key, new LaserOnlineModel(mapContext.getX()));
	}

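	/**
	 * Builds the L2-regularized logistic objective over the split's data and
	 * minimizes it with L-BFGS (function tolerance 1e-6), starting from and
	 * updating the weights held in the map context.
	 */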
	private LrIterationMapContext localMapperOptimization(
			LrIterationMapContext context) {
		LogisticL2DiffFunction logistic = new LogisticL2DiffFunction(
				context.getA(), context.getB(), context.getKnowOffset(),
				context.getX(), regularizationFactor);
		double[] optimum = lbfgs.minimize(logistic, 1e-6, context.getX());
		context.setX(optimum);
		return context;
	}
}