package tberg.murphy.structpred;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;

import tberg.murphy.arrays.a;
import tberg.murphy.counter.CounterInterface;
import tberg.murphy.counter.IntCounter;
import tberg.murphy.tuple.Pair;

public class NSlackSVMLearner<T> implements LossAugmentedLearner<T> {

	/**
	 * Tunable settings for the learner.
	 *
	 * Every field is {@code static}, so a value set here is shared by all instances of this
	 * class (and therefore by every learner), even though learners hold an instance reference.
	 */
	public static class SvmOpts {

		// ---- numerical tolerances ----

		/** Absolute threshold used by the convergence test and as slack when comparing objectives. */
		public static double precision = 1e-20;

		/** Relative tolerance handed to the convergence test. */
		public static double smoTol = 1e-4;

		/** Not referenced in the visible part of the learner. */
		public static double lossAugCheckTol = 0.1;

		/** Not referenced in the visible part of the learner. */
		public static double newAlphaMag = 0.0;

		/** Not referenced in the visible part of the learner. */
		public static double minCountThresh = 1e-6;

		/** Decode-time / SMO-time ratio below which the mini-batch size is re-estimated. */
		public static double minDecodeToSmoTimeRatio = 1.0;

		// ---- iteration budgets and sizes ----

		/** Number of passes made by the cache-based exponentiated-gradient optimizer. */
		public static int expGradIters = 0;

		/** Iteration cap for the full dual optimization that follows each pass over the data. */
		public static int smoIters = 10000;

		/** Iteration cap for the dual optimization that follows each mini-batch. */
		public static int innerSmoIters = 10;

		/** Maximum number of pairwise sweeps per instance inside one SMO round. */
		public static int innerInnerSmoIters = 3;

		/** Initial number of examples decoded per mini-batch. */
		public static int miniBatchSize = 8;

		/** Not referenced in the visible part of the learner. */
		public static int oneSlackCacheSize = 0;

		/** Not referenced in the visible part of the learner. */
		public static int cacheWarmup = 5;

		/** Not referenced in the visible part of the learner. */
		public static int maxInactiveAlphaCount = Integer.MAX_VALUE;

		// ---- behaviour switches ----

		/** Re-initialize all alphas after every mini-batch instead of only the first one. */
		public static boolean refreshAlphas = false;

		/** Test convergence on the primal/dual gap; otherwise on the change of the dual alone. */
		public static boolean smoCheckPrimal = true;

		/** Asserted to be true by the dual optimizer ("Dual SMO is broken"). */
		public static boolean primalSmo = true;

		/** Optimize only over the current mini-batch after decoding it. */
		public static boolean smoMiniBatch = true;

		/** Select the exponentiated-gradient optimizer (and uniform alpha initialization). */
		public static boolean expGradient = false;

		/** Select the projected-gradient optimizer; checked before {@link #expGradient}. */
		public static boolean projGradient = false;

		/** Emit progress output on standard out. */
		public static boolean svmVerbose = true;

	}

	// Accumulated time spent in dual optimization, summed from System.currentTimeMillis() differences.
	private double totalSMOTime;

	// Accumulated time spent decoding (reaping constraints), summed the same way.
	private double totalDecodeTime;

	// Not referenced in the visible part of this class.
	int maxLength;

	// Regularization constant; every use in this class scales it as C / indexToAlpha.length.
	double C;

	// Not referenced in the visible part of this class.
	int N;

	// Stored by the constructor; not read in the visible part of this class.
	double epsilon;

	// Not referenced in the visible part of this class.
	List<Integer> activeAlphasPriorityQueue[];

	// indexToDelta[i].get(yi): sparse feature difference for constraint yi of instance i
	// (presumably gold minus guess, as built by getDelta -- confirm where constraints are added).
	List<IntCounter>[] indexToDelta;

	// indexToDeltaNormSquared[i].get(yi): used as the squared norm of the matching delta.
	List<Double>[] indexToDeltaNormSquared;

	// indexToAlpha[i].get(yi): dual variable of constraint yi of instance i.
	// The array length is used throughout as the number of training instances.
	List<Double>[] indexToAlpha;

	// indexToLoss[i].get(yi): loss attached to constraint yi of instance i.
	List<Double>[] indexToLoss;

	// Delta-by-delta dot products indexed by absolute constraint index; nulled at the start of train().
	double[][] dotProdCache;

	// [instance][constraint] -> absolute constraint index, -1 until computed.
	int[][] alphaRelIndicesCache;

	// absolute constraint index -> instance index, -1 until computed.
	int[] alphaAbsIndicesCacheInstance;

	// absolute constraint index -> constraint index within its instance.
	int[] alphaAbsIndicesCacheConstraint;

	// Dense weight vector, indexed by feature index.
	double[] weights;

	// Reset to 0 by train() and reported in its output.
	int numDecodes;

	// Largest size reached by the sparse counter built in computeWeights(); used to presize it.
	int numFeaturesSoFar;

	// Largest feature index seen so far by computeWeights(); weights are wrapped with size (this + 1).
	int maxFeatureIndexSoFar = 0;;

	// Not referenced in the visible part of this class (shuffle() uses Collections.shuffle's own source).
	Random rand = new Random();

	// Option holder; note that all SvmOpts fields are static.
	protected SvmOpts opts;

	// Model being trained; assigned at the start of train().
	private LossAugmentedLinearModel<T> model;

	/**
	 * Creates a learner with a freshly constructed {@link SvmOpts}.
	 *
	 * @param C regularization constant
	 * @param epsilon stored on the learner
	 */
	public NSlackSVMLearner(double C, double epsilon) {
		this(C, epsilon, new SvmOpts());
	}

	/**
	 * Creates a learner with the given options object.
	 *
	 * @param C regularization constant
	 * @param epsilon stored on the learner
	 * @param opts option holder to use
	 */
	public NSlackSVMLearner(double C, double epsilon, SvmOpts opts) {
		this.C = C;
		this.epsilon = epsilon;
		this.opts = opts;
		numFeaturesSoFar = 0;
	}

	/**
	 * Trains the model for at most {@code maxIters} passes over {@code data}.
	 *
	 * Each pass walks the data in mini-batches: constraints are collected for the batch via
	 * {@code reapConstraints} (defined elsewhere in this class), the dual is optimized for
	 * {@code opts.innerSmoIters} iterations, and the mini-batch size is re-estimated from the
	 * measured decode and optimization times. After a pass that added constraints the dual is
	 * optimized over all instances and inactive alphas are pruned. Training stops early after the
	 * first pass that adds no constraint.
	 *
	 * NOTE(review): when {@code data} is empty (or {@code opts.miniBatchSize} is 0) the mini-batch
	 * size is 0 and {@code roundUp} divides by it.
	 * NOTE(review): the initial weights are wrapped with size {@code maxFeatureIndexSoFar + 1};
	 * whether that hides entries of {@code initWeights} with larger indices depends on
	 * {@code IntCounter.wrapArray} -- confirm.
	 *
	 * @param initWeights starting weights
	 * @param model model whose weights are set and which is notified at the start of each iteration
	 * @param data training examples
	 * @param maxIters maximum number of passes over the data
	 * @return a counter wrapping the final dense weight array
	 */
	public CounterInterface<Integer> train(CounterInterface<Integer> initWeights, LossAugmentedLinearModel<T> model, List<T> data, int maxIters) {
		// Start from a clean slate: caches, counters and timers.
		clearDotProductCache();
		clearIndicesCache();
		numDecodes = 0;
		totalDecodeTime = 0;
		totalSMOTime = 0;

		this.model = model;
		int miniBatchSize = Math.min(data.size(), opts.miniBatchSize);
		weights = toArray(initWeights);
		model.setWeights(wrapCounter(weights, maxFeatureIndexSoFar + 1));
		model.startIteration(0);
		for (int t = 0; t < maxIters; ++t) {
			System.out.println("Iteration " + t);
			// Number of constraints added during this pass; 0 means training is done.
			int numAdded = 0;
			int currStart = 0;
			do {

				final int currEnd = Math.min(data.size(), currStart + miniBatchSize);
				List<T> currData = data.subList(currStart, currEnd);
				if (opts.svmVerbose) System.out.println("Decoding batch from " + currStart + " to " + currEnd);
				final int numMiniBatches = roundUp(data, miniBatchSize);
				final int currMiniBatchIndex = currStart / miniBatchSize;
				// True only for the very first mini-batch of the very first pass.
				final boolean isFirst = t == 0 && currStart == 0;
				long startDecodeTime = System.currentTimeMillis();
				numAdded += reapConstraints(isFirst, model, currData, wrapCounter(weights, maxFeatureIndexSoFar + 1), currStart,
					numMiniBatches, currMiniBatchIndex, data.size());
				long endDecodeTime = System.currentTimeMillis();

				// The index caches are dropped after every reaping step.
				clearIndicesCache();

				if (opts.refreshAlphas || isFirst) {
					if (opts.expGradient) {
						uniformInitializeAlphas();
					} else {
						zeroInitializeAlphas();
					}
					// Alphas were reset, so rebuild the weights from them.
					weights = computeWeights();
					model.setWeights(wrapCounter(weights, maxFeatureIndexSoFar + 1));
				}
				long startSmoTime = System.currentTimeMillis();
				final int numSmoIters = opts.innerSmoIters;
				// Optimize either over the current batch only or over every instance.
				final int smoStart = opts.smoMiniBatch ? currStart : 0;
				final int smoEnd = opts.smoMiniBatch ? currEnd : indexToAlpha.length;
				final boolean checkConvergence = !opts.smoMiniBatch;
				optimizeDual(numSmoIters, smoStart, smoEnd, checkConvergence);
				long endSmoTime = System.currentTimeMillis();
				// Advance by the size that was used for this batch, before it is re-estimated.
				currStart += miniBatchSize;

				miniBatchSize = updateMiniBatchSize(miniBatchSize, currData, startDecodeTime, endDecodeTime, startSmoTime, endSmoTime);

				//				if (currentWeights != null) {
				//					if (!opts.primalSmo) {
				//						CounterInterface<Integer> weightsDelta = new IntCounter();
				//						weightsDelta.incrementAll(newWeights);
				//						weightsDelta.incrementAll(currentWeights, -1.0);
				//						System.out.printf("Mag of weights delta: %.8f\n", Math.sqrt(weightsDelta.dotProduct(weightsDelta)));
				//					}
				//					currentWeights = newWeights;
				//				} else {
				//					currentWeights = getWeights();
				//				}

			} while (currStart < data.size());

			// NOTE(review): prints the bare label with no value, regardless of opts.svmVerbose.
			System.out.println("Iteration ");

			if (numAdded > 0) {
				// New constraints arrived: optimize over all instances with convergence checking.
				optimizeDual(opts.smoIters, 0, indexToAlpha.length, true);
				pruneInactiveAlphas();
				//				Logger.logss("True primal is " + computeTruePrimal(data));
			}

			model.setWeights(wrapCounter(weights, maxFeatureIndexSoFar + 1));
			assert checkWeights();
			model.startIteration(t + 1);
			if (numAdded == 0) break;

			if (opts.svmVerbose) {
				System.out.printf("Num constraints: %d\n", numConstraints());
				System.out.printf("Num decodes so far: %d\n", numDecodes);
				System.out.printf("Total SMO time so far: %f\n", totalSMOTime);
				System.out.printf("Total decode time so far: %f\n", totalDecodeTime);
			}
		}

		System.out.printf("Num decodes: %d\n", numDecodes);
		System.out.printf("Total SMO time: %f\n", totalSMOTime);
		System.out.printf("Total decode time: %f\n", totalDecodeTime);

		final IntCounter wrapCounter = wrapCounter(weights, maxFeatureIndexSoFar + 1);
		model.setWeights(wrapCounter);
		return wrapCounter;
	}

	/**
	 * Debug check meant to be used inside an {@code assert}: recomputes the weights from the
	 * alphas and asserts that each entry is within 1e-3 of the incrementally maintained
	 * {@code weights}.
	 *
	 * @return always {@code true}; a mismatch surfaces as an AssertionError from the inner assert
	 */
	private boolean checkWeights() {
		final double[] recomputed = computeWeights();
		// Read the bound after computeWeights(), which may raise maxFeatureIndexSoFar.
		final int limit = maxFeatureIndexSoFar + 1;
		int idx = 0;
		while (idx < limit) {
			assert Math.abs(recomputed[idx] - weights[idx]) < 1e-3;
			idx++;
		}
		return true;
	}

	/**
	 * Runs one of the dual optimizers over instances {@code [start, end)}.
	 *
	 * The projected-gradient and exponentiated-gradient optimizers are only selected when
	 * {@code checkConvergence} is true; in every other case the SMO optimizer is used.
	 *
	 * @param numSmoIters iteration cap passed to the optimizer; 0 makes this method a no-op
	 * @param start first instance index (inclusive)
	 * @param end last instance index (exclusive)
	 * @param checkConvergence whether the optimizer should test for convergence
	 */
	private void optimizeDual(final int numSmoIters, int start, int end, boolean checkConvergence) {
		if (numSmoIters == 0) {
			return;
		}
		if (opts.svmVerbose && checkConvergence) {
			System.out.println("Optimizing dual");
		}

		assert opts.primalSmo : "Dual SMO is broken";

		if (checkConvergence && opts.projGradient) {
			optimizeDualObjectiveProjectedGradientPrimal(numSmoIters, start, end, checkConvergence);
		} else if (checkConvergence && opts.expGradient) {
			optimizeDualObjectiveExponentiatedGradientPrimal(numSmoIters, start, end, checkConvergence);
		} else {
			optimizeDualObjectiveSMOPrimalDual(numSmoIters, start, end, checkConvergence);
		}

		if (opts.svmVerbose && checkConvergence) {
			System.out.println();
		}
	}

	/**
	 * Adds the measured decode and optimization times to the running totals and, when decoding
	 * took less than {@code opts.minDecodeToSmoTimeRatio} times as long as optimizing, re-estimates
	 * the mini-batch size so that decoding a batch would take about that multiple of the
	 * optimization time.
	 *
	 * The returned size is always at least 1. The millisecond timers can report a duration of 0
	 * (or a negative one if the wall clock is adjusted); in those cases, and for an empty batch,
	 * no per-example decode time can be estimated and the current size is returned unchanged.
	 * Previously a zero decode time produced an infinite target, which rounds to
	 * {@code Long.MAX_VALUE} and narrows to -1.
	 *
	 * @param miniBatchSize current mini-batch size
	 * @param currData examples decoded in the batch that was just timed
	 * @param startDecodeTime decode start, in milliseconds
	 * @param endDecodeTime decode end, in milliseconds
	 * @param startSmoTime optimization start, in milliseconds
	 * @param endSmoTime optimization end, in milliseconds
	 * @return the mini-batch size to use from now on
	 */
	private int updateMiniBatchSize(int miniBatchSize, List<T> currData, long startDecodeTime, long endDecodeTime, long startSmoTime, long endSmoTime) {
		final double currSmoTime = endSmoTime - startSmoTime;
		final double currDecodeTime = endDecodeTime - startDecodeTime;
		totalSMOTime += currSmoTime;
		totalDecodeTime += currDecodeTime;

		// Without a positive decode time, a positive SMO time and a non-empty batch there is
		// nothing meaningful to extrapolate from.
		if (!(currDecodeTime > 0.0) || !(currSmoTime > 0.0) || currData.isEmpty()) {
			return miniBatchSize;
		}

		final double currDecodeToSmoRatio = currDecodeTime / currSmoTime;
		if (currDecodeToSmoRatio < opts.minDecodeToSmoTimeRatio) {
			final double avgDecodeTime = currDecodeTime / currData.size();
			final double target = opts.minDecodeToSmoTimeRatio * currSmoTime / avgDecodeTime;
			if (Double.isNaN(target)) {
				return miniBatchSize;
			}
			// Clamp before narrowing to int. The upper bound leaves headroom for the caller's
			// "currStart + miniBatchSize" arithmetic.
			final long maxBatch = Integer.MAX_VALUE / 2;
			miniBatchSize = (int) Math.max(1L, Math.min(maxBatch, Math.round(target)));
		}
		return miniBatchSize;
	}

	/**
	 * Dual objective expressed through the squared norm of the current weights:
	 * minus half of {@code wNormSquared} plus the sum over all constraints of
	 * (C / number of instances) * alpha * loss.
	 *
	 * @param wNormSquared squared norm of the weight vector to evaluate at
	 * @return the dual objective value
	 */
	public double getDualObjective(double wNormSquared) {
		final int numInstances = indexToAlpha.length;
		final double scale = C / numInstances;

		double dual = 0.0;
		dual -= 0.5 * wNormSquared;
		for (int i = 0; i < numInstances; ++i) {
			final List<Double> alphasI = indexToAlpha[i];
			final List<Double> lossesI = indexToLoss[i];
			final int count = numConstraints(i);
			for (int yi = 0; yi < count; ++yi) {
				dual += scale * alphasI.get(yi) * lossesI.get(yi);
			}
		}
		return dual;
	}

	/**
	 * Change in the dual objective if the alphas of instance {@code i} moved from {@code alphas}
	 * to {@code newAlphas}, computed without touching {@code weights}.
	 *
	 * Each constraint contributes a step t = (C / number of instances) * (new alpha - old alpha).
	 * The quadratic part subtracts half the squared-norm term, the interaction with the current
	 * weights and the pairwise cross terms; the linear part adds t * loss.
	 *
	 * @param i instance index
	 * @param alphas current alphas of instance i
	 * @param newAlphas proposed alphas of instance i
	 * @param dotProdCaches delta-by-delta dot products for instance i; only entries [y][z] with y &lt; z are read
	 * @return new objective minus current objective
	 */
	public double getDualObjectiveChange(int i, double[] alphas, double[] newAlphas, double[][] dotProdCaches) {
		final double scale = C / indexToAlpha.length;
		final int n = alphas.length;

		// Step contributed by each constraint.
		final double[] steps = new double[n];
		for (int y = 0; y < n; ++y) {
			steps[y] = scale * (newAlphas[y] - alphas[y]);
		}

		double change = 0.0;
		for (int y = 0; y < n; ++y) {
			final double ty = steps[y];
			change -= 0.5 * indexToDeltaNormSquared[i].get(y) * ty * ty;
			change -= indexToDelta[i].get(y).dotProduct(weights) * ty;
			for (int z = y + 1; z < n; ++z) {
				change -= dotProdCaches[y][z] * ty * steps[z];
			}
		}

		final int numConstraintsI = numConstraints(i);
		for (int y = 0; y < numConstraintsI; ++y) {
			change += steps[y] * indexToLoss[i].get(y);
		}
		return change;
	}

	/** Drops the relative/absolute constraint index caches; they are rebuilt lazily on next use. */
	private void clearIndicesCache() {
		alphaRelIndicesCache = null;
		alphaAbsIndicesCacheInstance = null;
		alphaAbsIndicesCacheConstraint = null;
	}

	/** Drops the delta-by-delta dot product cache. */
	private void clearDotProductCache() {
		dotProdCache = null;
	}

	/**
	 * Builds the feature difference {@code gold - guess} in a new counter; neither argument is modified.
	 *
	 * @param gold features of the gold structure
	 * @param guess features of the guessed structure
	 * @return a fresh counter holding gold minus guess
	 */
	protected static CounterInterface<Integer> getDelta(CounterInterface<Integer> gold, CounterInterface<Integer> guess) {
		CounterInterface<Integer> delta = new IntCounter();
		delta.incrementAll(gold);
		delta.incrementAll(guess, -1.0);
		return delta;
	}

	/**
	 * Number of mini-batches needed to cover {@code data}: the size divided by
	 * {@code miniBatchSize}, plus one when there is a remainder.
	 *
	 * @param data the full data set
	 * @param miniBatchSize batch size; 0 causes an ArithmeticException
	 * @return the number of mini-batches
	 */
	private int roundUp(List<T> data, int miniBatchSize) {
		final int total = data.size();
		int batches = total / miniBatchSize;
		if (total % miniBatchSize != 0) {
			batches += 1;
		}
		return batches;
	}

	/**
	 * Maps an absolute constraint index (constraints numbered consecutively across instances) to
	 * the pair (instance index, constraint index within that instance). Results are memoized
	 * until the index caches are cleared.
	 *
	 * @param absIndex_ absolute constraint index
	 * @return pair of instance index and per-instance constraint index
	 */
	Pair<Integer, Integer> getAlphaRelativeIndicesFromAbsoluteIndex(final int absIndex_) {
		if (alphaAbsIndicesCacheInstance == null) {
			final int total = numConstraints();
			alphaAbsIndicesCacheInstance = new int[total];
			alphaAbsIndicesCacheConstraint = new int[total];
			// -1 in the instance table marks an entry that has not been resolved yet.
			Arrays.fill(alphaAbsIndicesCacheInstance, -1);
		}

		if (alphaAbsIndicesCacheInstance[absIndex_] < 0) {
			// Walk the instances, peeling off each instance's constraint count.
			int remaining = absIndex_;
			int instance = 0;
			for (int count = numConstraints(instance); remaining >= count; count = numConstraints(instance)) {
				remaining -= count;
				instance++;
			}
			alphaAbsIndicesCacheInstance[absIndex_] = instance;
			alphaAbsIndicesCacheConstraint[absIndex_] = remaining;
		}

		final int cachedInstance = alphaAbsIndicesCacheInstance[absIndex_];
		final int cachedConstraint = alphaAbsIndicesCacheConstraint[absIndex_];
		return Pair.makePair(cachedInstance, cachedConstraint);
	}

	/**
	 * Maps (instance, constraint within instance) to the absolute constraint index, i.e. the
	 * number of constraints of all earlier instances plus {@code yi}. Results are memoized until
	 * the index caches are cleared.
	 *
	 * @param i instance index
	 * @param yi constraint index within instance i
	 * @return absolute constraint index
	 */
	int getAlphaAbsoluteIndexFromRelativeIndices(int i, int yi) {
		if (alphaRelIndicesCache == null) {
			final int numInstances = indexToAlpha.length;
			alphaRelIndicesCache = new int[numInstances][];
			for (int row = 0; row < numInstances; ++row) {
				final int[] unresolved = new int[numConstraints(row)];
				// -1 marks an entry that has not been resolved yet.
				Arrays.fill(unresolved, -1);
				alphaRelIndicesCache[row] = unresolved;
			}
		}

		final int[] rowI = alphaRelIndicesCache[i];
		if (rowI[yi] < 0) {
			int offset = 0;
			for (int earlier = 0; earlier < i; ++earlier) {
				offset += numConstraints(earlier);
			}
			rowI[yi] = offset + yi;
		}
		return rowI[yi];
	}

	/**
	 * Primal objective evaluated at the current {@code weights}.
	 *
	 * @return the primal objective value
	 */
	public double getPrimalObjective() {
		final double weightNormSquared = a.innerProd(weights, weights);
		return getPrimalObjective(weights, weightNormSquared);
	}

	/**
	 * Copies a sparse counter into a dense array indexed by feature index.
	 *
	 * The array starts with {@code maxFeatureIndexSoFar + 1} slots and grows (by at least half of
	 * its length) whenever a larger key is encountered, so the result may be longer than
	 * {@code maxFeatureIndexSoFar + 1}. {@code maxFeatureIndexSoFar} itself is not updated here.
	 *
	 * @param weights sparse weights; keys are assumed to be non-negative
	 * @return dense copy of the counter
	 */
	private double[] toArray(CounterInterface<Integer> weights) {
		double[] dense = new double[maxFeatureIndexSoFar + 1];
		for (Entry<Integer, Double> entry : weights.entries()) {
			final int key = entry.getKey();
			if (key >= dense.length) {
				// Grow geometrically, but always far enough to hold this key.
				dense = Arrays.copyOf(dense, Math.max(key + 1, dense.length * 3 / 2));
			}
			dense[key] = entry.getValue();
		}
		return dense;
	}

	/**
	 * Primal objective for the given weights: half the squared norm plus, for every instance,
	 * (C / number of instances) times the largest slack among that instance's constraints
	 * (never less than 0).
	 *
	 * @param weights dense weight vector to evaluate
	 * @param weightNormSquared squared norm of {@code weights}, supplied by the caller
	 * @return the primal objective value
	 */
	private double getPrimalObjective(double[] weights, double weightNormSquared) {
		final int numInstances = indexToAlpha.length;
		final double slackScale = C / numInstances;

		double primal = 0.0;
		primal += 0.5 * weightNormSquared;
		for (int i = 0; i < numInstances; ++i) {
			double worstSlack = 0.0;
			int yi = 0;
			while (yi < numConstraints(i)) {
				worstSlack = Math.max(worstSlack, getContraintSlack(i, yi, weights));
				yi++;
			}
			primal += slackScale * worstSlack;
		}
		return primal;
	}

	/**
	 * Rebuilds the weight vector from scratch as the sum over all constraints of
	 * (C / number of instances) * alpha * delta, and returns it as a dense array.
	 *
	 * Side effects: raises {@code maxFeatureIndexSoFar} to the largest feature index found in any
	 * delta and {@code numFeaturesSoFar} to the size of the accumulated counter. The
	 * {@code weights} field is not modified.
	 *
	 * @return freshly computed dense weights
	 */
	public double[] computeWeights() {
		final IntCounter accumulated = new IntCounter(numFeaturesSoFar);
		final int numInstances = indexToAlpha.length;
		final double instanceScale = C / numInstances;

		for (int i = 0; i < numInstances; ++i) {
			for (int yi = 0; yi < numConstraints(i); ++yi) {
				final double alpha = indexToAlpha[i].get(yi);
				// Alphas are expected to stay in [0, 1] up to a small numerical slack.
				assert alpha >= -1e-4;
				assert alpha <= 1 + 1e-4;
				for (Map.Entry<Integer, Double> feature : indexToDelta[i].get(yi).entries()) {
					maxFeatureIndexSoFar = Math.max(feature.getKey(), maxFeatureIndexSoFar);
					// Exact zeros are skipped so they do not create entries in the counter.
					if (feature.getValue() != 0.0) {
						accumulated.incrementCount(feature.getKey(), instanceScale * alpha * feature.getValue());
					}
				}
			}
		}

		numFeaturesSoFar = Math.max(numFeaturesSoFar, accumulated.size());
		return toArray(accumulated);
	}

	/**
	 * Slack of constraint {@code yi} of instance {@code i} under sparse weights:
	 * its loss minus delta dot weights, floored at 0.
	 */
	double getContraintSlack(int i, int yi, CounterInterface<Integer> weights) {
		return Math.max(0.0, indexToLoss[i].get(yi) - indexToDelta[i].get(yi).dotProduct(weights));
	}

	/**
	 * Slack of constraint {@code yi} of instance {@code i} under dense weights:
	 * its loss minus delta dot weights, floored at 0.
	 */
	double getContraintSlack(int i, int yi, double[] weights) {
		return Math.max(0.0, indexToLoss[i].get(yi) - indexToDelta[i].get(yi).dotProduct(weights));
	}

	/**
	 * Slack of an explicitly given constraint: {@code loss} minus {@code delta} dot
	 * {@code weights}, floored at 0.
	 */
	double getContraintSlack(CounterInterface<Integer> weights, CounterInterface<Integer> delta, double loss) {
		return Math.max(loss - delta.dotProduct(weights), 0.0);
	}

	/**
	 * Dual objective computed directly from the alphas and {@code dotProdCache}: the sum over
	 * constraints of scale * alpha * loss, minus half of scale^2 * alpha_a * alpha_b * (delta_a dot
	 * delta_b) summed over all ordered pairs of constraints, where scale is C divided by the
	 * number of instances.
	 *
	 * Reads {@code dotProdCache}, which must have been built beforehand.
	 *
	 * @return the dual objective value
	 */
	public double getDualObjective() {
		final int numInstances = indexToAlpha.length;
		final double scale = C / numInstances;

		double dual = 0.0;
		for (int i = 0; i < numInstances; ++i) {
			final int countI = numConstraints(i);
			for (int yi = 0; yi < countI; ++yi) {
				final double[] cacheRow = dotProdCache[getAlphaAbsoluteIndexFromRelativeIndices(i, yi)];
				final double alphaI = indexToAlpha[i].get(yi);

				// Linear term.
				dual += scale * alphaI * indexToLoss[i].get(yi);

				// Quadratic term against every constraint of every instance.
				final double quadCoefficient = 0.5 * scale * scale * alphaI;
				for (int j = 0; j < numInstances; ++j) {
					final int countJ = numConstraints(j);
					final List<Double> alphasJ = indexToAlpha[j];
					for (int yj = 0; yj < countJ; ++yj) {
						final double alphaJ = alphasJ.get(yj);
						dual -= quadCoefficient * alphaJ * cacheRow[getAlphaAbsoluteIndexFromRelativeIndices(j, yj)];
					}
				}
			}
		}
		return dual;
	}

	/**
	 * Gradient of the dual objective with respect to the alphas of instance {@code i}, computed
	 * from {@code dotProdCache}: scale * loss minus scale^2 * sum over all constraints of
	 * alpha * (delta dot delta), where scale is C divided by the number of instances.
	 *
	 * @param i instance index
	 * @return one gradient entry per constraint of instance i
	 */
	double[] getDualGradient(int i) {
		final double scale = C / indexToAlpha.length;
		final double[] gradient = new double[numConstraints(i)];

		for (int yi = 0; yi < numConstraints(i); ++yi) {
			double entry = scale * indexToLoss[i].get(yi);
			for (int j = 0; j < indexToAlpha.length; ++j) {
				for (int yj = 0; yj < numConstraints(j); ++yj) {
					final int rowIndex = getAlphaAbsoluteIndexFromRelativeIndices(i, yi);
					final int columnIndex = getAlphaAbsoluteIndexFromRelativeIndices(j, yj);
					entry -= scale * scale * indexToAlpha[j].get(yj) * dotProdCache[rowIndex][columnIndex];
				}
			}
			gradient[yi] = entry;
		}
		return gradient;
	}

	/**
	 * Gradient of the dual objective with respect to the alphas of instance {@code i}, computed
	 * from the given weights instead of the dot product cache:
	 * (C / number of instances) * loss minus (C / number of instances) * (delta dot weights).
	 *
	 * @param i instance index
	 * @param weights dense weight vector
	 * @return one gradient entry per constraint of instance i
	 */
	double[] getDualGradient(int i, double[] weights) {
		final double scale = C / indexToAlpha.length;
		final double[] gradient = new double[numConstraints(i)];

		int yi = 0;
		while (yi < numConstraints(i)) {
			final double lossTerm = scale * indexToLoss[i].get(yi);
			final double marginTerm = scale * indexToDelta[i].get(yi).dotProduct(weights);
			gradient[yi] = lossTerm - marginTerm;
			yi++;
		}
		return gradient;
	}

	/**
	 * Exponentiated-gradient ascent on the dual objective using {@code dotProdCache}.
	 *
	 * Runs {@code opts.expGradIters} passes. In each pass every instance's alphas are multiplied
	 * by exp(step * gradient) and renormalized. The per-instance step size grows by 10% when the
	 * objective did not drop by more than {@code opts.precision} and is halved otherwise. An
	 * update is retried with the smaller step only while the resulting objective is not finite;
	 * any finite result is kept.
	 *
	 * Progress is printed on the first pass, every 500th pass and the last pass.
	 */
	public void optimizeDualObjectiveExponentiatedGradient() {
		double[] stepSizes = new double[indexToAlpha.length];
		Arrays.fill(stepSizes, 0.1);
		for (int iter = 1; iter <= opts.expGradIters; iter++) {
			double objective = getDualObjective();
			for (int i = 0; i < indexToAlpha.length; ++i) {
				double[] alphas = getAlphas(i);
				double[] grad = getDualGradient(i);
				double newObjective = Double.NEGATIVE_INFINITY;
				while (true) {
					double[] direction = a.scale(grad, stepSizes[i]);
					double[] scale = a.exp(direction);
					double[] newAlphas = a.pointwiseMult(alphas, scale);
					normalize(newAlphas);
					// The alphas are written first because getDualObjective() reads them.
					setAlphas(i, newAlphas);
					newObjective = getDualObjective();
					boolean isFinite = !Double.isInfinite(newObjective) && !Double.isNaN(newObjective);
					if (isFinite && newObjective >= objective - opts.precision) {
						stepSizes[i] *= 1.1;
					} else {
						stepSizes[i] *= 0.5;
					}
					if (isFinite) break;
				}
				objective = newObjective;
			}
			// Format once: the message used to be pre-formatted and then used as a format string.
			if (iter == 1 || iter % 500 == 0 || iter == opts.expGradIters) System.out.printf("[ExpGrad] Iter %d: %.8f\n", iter, objective);
		}
	}

	/**
	 * Exponentiated-gradient ascent on the dual objective for instances {@code [start, end)},
	 * keeping {@code weights} in sync with the alphas.
	 *
	 * For each instance a trial step multiplies the alphas by exp(step * gradient) (shifted by the
	 * largest exponent), renormalizes, and floors every entry at {@code Double.MIN_VALUE}. When
	 * {@code checkConvergence} is true a trial is accepted only if the resulting dual objective is
	 * finite and not smaller than the current one; otherwise every trial is accepted. On
	 * acceptance the step size is doubled (up to 1e5) and the weight change is applied. On
	 * rejection the stored alphas are restored, the step size is multiplied by 0.2 and the trial
	 * is repeated, until the step size underflows to 0.
	 *
	 * A rejected trial therefore leaves the alphas, {@code weights}, the tracked objective and
	 * the tracked squared weight norm exactly as they were before the trial.
	 *
	 * @param maxIters maximum number of passes over the instances
	 * @param start first instance index (inclusive)
	 * @param end last instance index (exclusive)
	 * @param checkConvergence whether to evaluate objectives, print progress and stop on convergence
	 */
	public void optimizeDualObjectiveExponentiatedGradientPrimal(int maxIters, int start, int end, boolean checkConvergence) {

		double[] stepSizes = new double[indexToAlpha.length];
		Arrays.fill(stepSizes, 1.0);
		IntCounter deltaScratch = new IntCounter();
		for (int iter = 1; iter <= maxIters; iter++) {
			double weightsNormSquared = a.innerProd(weights, weights);

			double objective = getDualObjective(weightsNormSquared);
			for (int i = start; i < end; ++i) {
				double[] alphas = getAlphas(i);
				double[] grad = getDualGradient(i, weights);
				double newObjective = Double.NEGATIVE_INFINITY;
				double newWeightsNormSquared = Double.NaN;
				boolean accepted = false;
				while (true) {
					double[] direction = a.scale(grad, stepSizes[i]);
					// Shift so the largest exponent is 0; normalization cancels the shift.
					double maxDirection = a.max(direction);
					a.addi(direction, -maxDirection);

					double[] scale = a.exp(direction);

					double[] newAlphas = a.pointwiseMult(alphas, scale);
					normalize(newAlphas);
					// Keep every alpha strictly positive so a multiplicative update can still move it.
					for (int yi = 0; yi < newAlphas.length; ++yi) {
						newAlphas[yi] = Math.max(newAlphas[yi], Double.MIN_VALUE);
					}
					if (a.hasnan(newAlphas) || a.hasinf(newAlphas)) {
						// Nothing has been written yet for this trial; just shrink the step.
						stepSizes[i] *= 0.2;
						if (stepSizes[i] == 0) break;
						continue;
					}
					deltaScratch.clear();
					// The trial alphas are written first because getDualObjective(double) reads them.
					IntCounter weightsDelta = setAlphas(i, alphas, newAlphas, deltaScratch);
					boolean acceptable = true;
					if (checkConvergence) {
						newWeightsNormSquared = weightsNormSquared + 2.0 * weightsDelta.dotProduct(weights) + weightsDelta.normSquared();
						newObjective = getDualObjective(newWeightsNormSquared);
						boolean isFinite = !Double.isInfinite(newObjective) && !Double.isNaN(newObjective);
						acceptable = isFinite && newObjective >= objective;
					}
					if (!acceptable) {
						// Roll the stored alphas back so they stay consistent with the weights.
						setAlphas(i, alphas);
						stepSizes[i] *= 0.2;
						if (stepSizes[i] == 0) break;
						continue;
					}
					if (stepSizes[i] < 1e5) stepSizes[i] *= 2.0;
					incrementAll(weights, weightsDelta, 1.0);
					accepted = true;
					break;
				}
				// Only an applied step changes the tracked objective and squared norm.
				if (accepted && checkConvergence) {
					objective = newObjective;
					weightsNormSquared = newWeightsNormSquared;
				}
			}
			if (checkConvergence) {
				double primalObjective = getPrimalObjective(weights, weightsNormSquared);
				if (converged(primalObjective, objective, opts.smoTol, Double.NaN)) {
					break;
				}
				System.out.printf("[ExpGrad]  Iter %d: %.8f, %.8f\n", iter, objective, primalObjective);
			}
		}

	}

	/**
	 * Projected-gradient ascent on the dual objective for instances {@code [start, end)},
	 * keeping {@code weights} in sync with the alphas.
	 *
	 * For each instance a trial point is obtained by stepping along the gradient and projecting
	 * back with {@code projectToSimplex}. The resulting change of the dual objective is computed
	 * analytically. A positive change is applied (and the step size enlarged by 20%, up to 1e10);
	 * a non-positive or NaN change halves the step size and the trial is repeated, until the
	 * trial equals the current alphas, the change is exactly 0, or the step size underflows to 0.
	 * Only applied changes are added to the tracked objective.
	 *
	 * @param maxIters maximum number of passes over the instances
	 * @param start first instance index (inclusive)
	 * @param end last instance index (exclusive)
	 * @param checkConvergence whether to print progress and stop on convergence
	 */
	public void optimizeDualObjectiveProjectedGradientPrimal(int maxIters, int start, int end, boolean checkConvergence) {

		if (checkConvergence) {
			System.out.println("Building delta cache");
		}
		// Per-instance upper-triangular table of delta-by-delta dot products, indexed by (i - start).
		double[][][] dotProdCaches = new double[end - start][][];
		for (int i = start; i < end; ++i) {
			final int numConstraints = numConstraints(i);
			dotProdCaches[i - start] = new double[numConstraints][numConstraints];

			for (int yi = 0; yi < numConstraints; ++yi) {
				IntCounter delta_yi = indexToDelta[i].get(yi);
				for (int yj = yi + 1; yj < numConstraints; ++yj) {

					IntCounter delta_yj = indexToDelta[i].get(yj);
					dotProdCaches[i - start][yi][yj] = delta_yi.dotProduct(delta_yj);

				}
			}
		}

		double[] stepSizes = new double[indexToAlpha.length];
		Arrays.fill(stepSizes, 1.0);
		for (int iter = 1; iter <= maxIters; iter++) {
			double weightsNormSquared = a.innerProd(weights, weights);

			double objective = getDualObjective(weightsNormSquared);
			for (int i = start; i < end; ++i) {
				double[] alphas = getAlphas(i);
				double[] grad = getDualGradient(i, weights);
				// Stays 0 unless a trial is actually applied below.
				double appliedChange = 0.0;
				while (true) {
					double[] newAlphas = projectToSimplex(toList(a.comb(alphas, 1.0, grad, stepSizes[i])));
					if (Arrays.equals(alphas, newAlphas)) break;
					// The cache is offset by start, matching how it was filled above.
					final double objectiveChange = getDualObjectiveChange(i, alphas, newAlphas, dotProdCaches[i - start]);

					if (objectiveChange == 0.0)
						break;
					else if (objectiveChange > 0.0) {
						if (stepSizes[i] < 1e10) stepSizes[i] *= 1.2;
					} else {
						stepSizes[i] *= 0.5;
						if (stepSizes[i] == 0) break;
						continue;
					}
					for (int yi = 0; yi < newAlphas.length; ++yi) {
						indexToAlpha[i].set(yi, newAlphas[yi]);
						updateWeights(indexToDelta[i].get(yi), C / indexToAlpha.length * (newAlphas[yi] - alphas[yi]));
					}
					appliedChange = objectiveChange;
					break;
				}
				objective += appliedChange;
			}
			if (checkConvergence) {
				double primalObjective = getPrimalObjective(weights, a.innerProd(weights, weights));
				if (converged(primalObjective, objective, opts.smoTol, Double.NaN)) {
					System.out.printf("[ProjGrad]  Final Iter %d: %.8f, %.8f\n", iter, objective, primalObjective);
					break;
				}
				System.out.printf("[ProjGrad]  Iter %d: %.8f, %.8f\n", iter, objective, primalObjective);
			}
		}

	}

	/**
	 * Boxes a primitive array into a new list, preserving order.
	 *
	 * @param a values to copy
	 * @return a new list with one element per array entry
	 */
	private static List<Double> toList(double[] a) {
		final int length = a.length;
		final List<Double> boxed = new ArrayList<Double>(length);
		for (int idx = 0; idx < length; ++idx) {
			boxed.add(a[idx]);
		}
		return boxed;
	}

	/**
	 * Scales {@code vect} in place so that its entries sum to 1. The array is left untouched
	 * when the sum is not strictly positive (zero, negative or NaN).
	 *
	 * @param vect values to normalize in place
	 */
	static void normalize(double[] vect) {
		double total = 0.0;
		for (int k = 0; k < vect.length; ++k) {
			total += vect[k];
		}
		if (!(total > 0)) {
			return;
		}
		for (int k = 0; k < vect.length; ++k) {
			vect[k] /= total;
		}
	}

	/**
	 * Overwrites the stored alphas of instance {@code i} with the given values. The weights are
	 * not touched.
	 *
	 * @param i instance index
	 * @param alphas new alpha per constraint of instance i
	 */
	void setAlphas(int i, double[] alphas) {
		int yi = 0;
		while (yi < numConstraints(i)) {
			indexToAlpha[i].set(yi, alphas[yi]);
			yi++;
		}
	}

	/**
	 * Stores the changed alphas of instance {@code i} and accumulates the matching weight change,
	 * (new alpha - old alpha) * C / number of instances times the constraint's delta, into
	 * {@code weightsDelta}. Constraints whose alpha did not change are skipped entirely. The
	 * {@code weights} field itself is not modified.
	 *
	 * @param i instance index
	 * @param oldAlphas alphas the weights currently correspond to
	 * @param alphas new alphas
	 * @param weightsDelta accumulator that receives the weight change
	 * @return {@code weightsDelta}, for convenience
	 */
	IntCounter setAlphas(int i, double[] oldAlphas, double[] alphas, IntCounter weightsDelta) {
		for (int yi = 0; yi < numConstraints(i); ++yi) {
			final double change = alphas[yi] - oldAlphas[yi];
			if (change != 0.0) {
				indexToAlpha[i].set(yi, alphas[yi]);
				final double coefficient = change * C / indexToAlpha.length;

				// Make room up front for the features this delta may add.
				final int needed = weightsDelta.size() + indexToDelta[i].get(yi).size();
				weightsDelta.ensureCapacity(needed);
				weightsDelta.incrementAll(indexToDelta[i].get(yi), coefficient);
			}
		}
		return weightsDelta;
	}

	/**
	 * Copies the stored alphas of instance {@code i} into a new primitive array.
	 *
	 * @param i instance index
	 * @return a copy holding one alpha per constraint of instance i
	 */
	double[] getAlphas(int i) {
		final double[] copy = new double[numConstraints(i)];
		int yi = 0;
		while (yi < numConstraints(i)) {
			copy[yi] = indexToAlpha[i].get(yi);
			yi++;
		}
		return copy;
	}

	/**
	 * Cache-based SMO: for up to {@code opts.smoIters} rounds, updates every ordered pair of
	 * distinct constraints within each instance via {@code updateAlphas}, then evaluates the dual
	 * (and, if {@code opts.smoCheckPrimal}, the primal) and stops once converged.
	 *
	 * The dual value is printed on round 1, every 200th round, on convergence and on the last round.
	 */
	public void optimizeDualObjectiveSMO() {
		double previousDual = Double.NEGATIVE_INFINITY;
		for (int round = 1; round <= opts.smoIters; ++round) {
			for (int i = 0; i < indexToAlpha.length; ++i) {
				for (int yi = 0; yi < numConstraints(i); ++yi) {
					for (int yj = 0; yj < numConstraints(i); ++yj) {
						if (yi == yj) {
							continue;
						}
						updateAlphas(i, yi, yj);
					}
				}
			}

			final double dual = getDualObjective();
			final double primal = opts.smoCheckPrimal ? getPrimalObjective() : Double.NaN;
			final boolean done = converged(primal, dual, opts.smoTol, previousDual);
			final boolean report = round == 1 || round % 200 == 0 || done || round == opts.smoIters;
			if (report) {
				System.out.printf("[SMO] Round %d: %.8f\n", round, dual);
			}
			if (done) {
				break;
			}
			previousDual = dual;
		}
	}

	/**
	 * Evaluates half the squared weight norm plus the slack of one freshly decoded loss-augmented
	 * constraint per update bundle returned by the model for {@code data}.
	 *
	 * NOTE(review): unlike getPrimalObjective, the slacks here are not scaled by
	 * C / number of instances -- confirm whether that is intended.
	 *
	 * @param data examples to decode
	 * @return the accumulated value
	 */
	private double computeTruePrimal(List<T> data) {
		double total = 0.5 * a.innerProd(weights, weights);
		final List<UpdateBundle> bundles = model.getLossAugmentedUpdateBundleBatch(data, 1.0);
		for (UpdateBundle bundle : bundles) {
			// gold minus guess
			final IntCounter difference = new IntCounter();
			difference.incrementAll(bundle.gold);
			difference.incrementAll(bundle.guess, -1.0);
			final IntCounter wrappedWeights = IntCounter.wrapArray(weights, maxFeatureIndexSoFar + 1);
			total += getContraintSlack(wrappedWeights, difference, bundle.loss);
		}
		return total;
	}

	/**
	 * SMO over instances {@code [start, end)} that works on the weights directly.
	 *
	 * First builds, per instance, two upper-triangular tables over constraint pairs (yi &lt; yj):
	 * the dot product of the two deltas, and the value
	 * normSq(yi) - 2 * dot + normSq(yj) taken from {@code indexToDeltaNormSquared}.
	 * Then, for up to {@code maxIters} rounds, visits the instances in shuffled order and runs at
	 * most {@code opts.innerInnerSmoIters} sweeps over that instance's constraint pairs, calling
	 * {@code updateAlphasPrimal} (defined elsewhere in this class) for each pair; sweeping stops
	 * early when a sweep reports no change.
	 *
	 * With {@code checkConvergence} the dual (and optionally primal) objective is evaluated after
	 * every round and the loop ends on convergence or on the last round. Without it the loop ends
	 * after the first round in which nothing changed.
	 *
	 * @param maxIters maximum number of rounds
	 * @param start first instance index (inclusive)
	 * @param end last instance index (exclusive)
	 * @param checkConvergence whether to evaluate objectives and test for convergence
	 */
	public void optimizeDualObjectiveSMOPrimalDual(int maxIters, int start, int end, boolean checkConvergence) {
		if (checkConvergence) {

			if (opts.svmVerbose) System.out.println("Building delta cache");
		}
		// Both tables are indexed by (i - start) and only filled for yi < yj.
		double[][][] deltaDeltaNormCache = new double[end - start][][];
		double[][][] deltaDeltaDotProductCache = new double[end - start][][];
		for (int i = start; i < end; ++i) {
			final int numConstraints = numConstraints(i);
			deltaDeltaNormCache[i - start] = new double[numConstraints][numConstraints];
			deltaDeltaDotProductCache[i - start] = new double[numConstraints][numConstraints];

			for (int yi = 0; yi < numConstraints; ++yi) {
				final double delta_yi_sq = indexToDeltaNormSquared[i].get(yi);
				IntCounter delta_yi = indexToDelta[i].get(yi);
				for (int yj = yi + 1; yj < numConstraints; ++yj) {

					final double delta_yj_sq = indexToDeltaNormSquared[i].get(yj);
					IntCounter delta_yj = indexToDelta[i].get(yj);
					//							IntCounter delta = new IntCounter();
					//							delta.incrementAll(indexToDelta[i].get(yi));
					//							delta.incrementAll(indexToDelta[i].get(yj), -1.0);
					//							IntCounter delta_ = new IntCounter();
					//							delta_.incrementAll(delta);
					final double dotProd = delta_yi.dotProduct(delta_yj);
					deltaDeltaDotProductCache[i - start][yi][yj] = dotProd;
					deltaDeltaNormCache[i - start][yi][yj] = delta_yi_sq - 2.0 * dotProd + delta_yj_sq;

				}
			}
		}

		double lastDual = Double.NEGATIVE_INFINITY;

		for (int iter = 1; iter <= maxIters; ++iter) {
			// Set when any sweep of any instance in this round reported a change.
			boolean weightsChanged = false;

			for (int i : shuffle(start, end)) {
				final int currNumConstraints = numConstraints(i);
				final double[][] deltaDeltaNormCacheHere = deltaDeltaNormCache[i - start];
				final double[][] deltaDeltaDotProductCacheHere = deltaDeltaDotProductCache[i - start];
				//				INNER: for (int yi = 0; yi < 1/* currNumConstraints */; ++yi) {
				//					for (int yj = 1/* 0 */; yj < currNumConstraints; ++yj) {
				// The outer order is drawn once per instance; the inner order is redrawn per sweep.
				final List<Integer> shuffleOuter = shuffle(0, currNumConstraints);
				// delta dot weights for every constraint of this instance, computed once up front and
				// handed to updateAlphasPrimal (presumably kept current there -- confirm).
				double[] deltWeightsDotProductCache = new double[currNumConstraints];
				for (int yi = 0; yi < currNumConstraints; ++yi) {
					deltWeightsDotProductCache[yi] = indexToDelta[i].get(yi).dotProduct(weights);
				}

				for (int innerIter = 0; innerIter < opts.innerInnerSmoIters; ++innerIter) {
					boolean innerWeightsChanged = false;
					final List<Integer> shuffleInner = shuffle(0, currNumConstraints);
					for (int yi : shuffleOuter) {
						for (int yj : shuffleInner) {
							// Each unordered pair is handled once, always with the smaller index first.
							if (yi < yj) {
								innerWeightsChanged |= updateAlphasPrimal(i, yi, yj, deltWeightsDotProductCache, deltaDeltaNormCacheHere,
									deltaDeltaDotProductCacheHere);
								//							if (weightsChanged) continue INNER;

							}
						}
					}
					if (!innerWeightsChanged)
						break;
					else
						weightsChanged = true;
				}
			}
			if (checkConvergence) {
				double wNormSquared = a.innerProd(weights, weights);//weights.dotProduct(weights);
				double dual = getDualObjective(wNormSquared);

				double primal = opts.smoCheckPrimal ? getPrimalObjective(weights, wNormSquared) : Double.NaN;
				if (opts.svmVerbose) System.out.printf("[SMO] Round %d: %.8f, %.8f\n", iter, dual, primal);

				if (converged(primal, dual, opts.smoTol, lastDual) || iter == maxIters) {
					if (opts.svmVerbose) System.out.printf("[SMO] Final Round %d: %.8f, %.8f\n", iter, dual, primal);
					break;
				}
				lastDual = dual;
			} else {
				// Without objective checks, a round that changed nothing ends the optimization.
				if (!weightsChanged) {//
					break;
				}
			}
		}
	}

	//	public void optimizeDualObjectiveAnalytic(int maxIters, int start, int end, boolean checkConvergence) {
	//		//		if (checkConvergence) {
	//		//
	//		//			System.out.println("Building delta cache");
	//		//		}
	//		//		double[][][] deltaDeltaNormCache = new double[end - start][][];
	//		//		for (int i = start; i < end; ++i) {
	//		//			final int numConstraints = numConstraints(i);
	//		//			deltaDeltaNormCache[i - start] = new double[numConstraints][numConstraints];
	//		//
	//		//			for (int yi = 0; yi < numConstraints; ++yi) {
	//		//				final double delta_yi_sq = indexToDeltaNormSquared[i].get(yi);
	//		//				IntCounter delta_yi = indexToDelta[i].get(yi);
	//		//				for (int yj = yi + 1; yj < numConstraints; ++yj) {
	//		//
	//		//					final double delta_yj_sq = indexToDeltaNormSquared[i].get(yj);
	//		//					IntCounter delta_yj = indexToDelta[i].get(yj);
	//		//					//							IntCounter delta = new IntCounter();
	//		//					//							delta.incrementAll(indexToDelta[i].get(yi));
	//		//					//							delta.incrementAll(indexToDelta[i].get(yj), -1.0);
	//		//					//							IntCounter delta_ = new IntCounter();
	//		//					//							delta_.incrementAll(delta);
	//		//					deltaDeltaNormCache[i - start][yi][yj] = delta_yi_sq - 2.0 * delta_yi.dotProduct(delta_yj) + delta_yj_sq;
	//		//
	//		//				}
	//		//			}
	//		//		}
	//		//
	//		//		if (checkConvergence) Logger.endTrack();
	//		double lastDual = Double.NEGATIVE_INFINITY;
	//
	//		for (int iter = 1; iter <= maxIters; ++iter) {
	//			boolean weightsChanged = false;
	//			for (int i : shuffle(start, end)) {
	//				final int currNumConstraints = numConstraints(i);
	//				//				final double[][] deltaDeltaNormCacheHere = deltaDeltaNormCache[i - start];
	//				//				double[] cache = new double[currNumConstraints];
	//				//				Arrays.fill(cache, Double.NaN);
	//				//				INNER: for (int yi = 0; yi < 1/* currNumConstraints */; ++yi) {
	//				//					for (int yj = 1/* 0 */; yj < currNumConstraints; ++yj) {
	//				//				final List<Integer> shuffleOuter = shuffle(0, currNumConstraints);
	//				//				for (int yi : shuffleOuter) {
	//				//					final List<Integer> shuffleInner = shuffle(yi + 1, currNumConstraints);
	//				//					for (int yj : shuffleInner) {
	//				//						if (yi < yj) {
	//				weightsChanged |= updateAlphasAnalytic(i);
	//				//							if (weightsChanged) continue INNER;
	//
	//				//						}
	//				//					}
	//				//				}
	//			}
	//			if (checkConvergence) {
	//				double wNormSquared = ArrayUtil.normSquared(weights);//weights.dotProduct(weights);
	//				double dual = getDualObjective(wNormSquared);
	//
	//				double primal = opts.smoCheckPrimal ? getPrimalObjective(weights, wNormSquared) : Double.NaN;
	//				System.out.printf("[SMO] Round %d: %.8f, %.8f\n", iter, dual, primal);
	//
	//				if (converged(primal, dual, opts.smoTol, lastDual) || iter == maxIters) {
	//					System.out.printf("[SMO] Final Round %d: %.8f, %.8f\n", iter, dual, primal);
	//					break;
	//				}
	//				lastDual = dual;
	//			} else {
	//				if (!weightsChanged) {//
	//					break;
	//				}
	//			}
	//		}
	//	}

	//	private boolean updateAlphasAnalytic(int i) {
	//		double[] oldAlphas = getAlphas(i);
	//		final int currNumConstraints = numConstraints(i);
	//		List<Pair<Integer, Double>> Ds = new ArrayList<Pair<Integer, Double>>(currNumConstraints);
	//		for (int yi = 0; yi < currNumConstraints; ++yi) {
	//			final double deltaNormSquared = indexToDeltaNormSquared[i].get(yi);
	//			Ds.add(Pair.newPair(yi, deltaNormSquared == 0.0 ? 0.0 : (indexToLoss[i].get(yi) - indexToDelta[i].get(yi).dotProduct(weights)) / deltaNormSquared));
	//		}
	//		Collections.sort(Ds, new Comparator<Pair<Integer, Double>>()
	//		{
	//
	//			@Override
	//			public int compare(Pair<Integer, Double> arg0, Pair<Integer, Double> arg1) {
	//				return Double.compare(arg1.getSecond(), arg0.getSecond());
	//			}
	//		});
	//		int r = -1;
	//		double phi = 1.0;
	//		double lastPhi = phi;
	//		while (phi > 0.0) {
	//			r++;
	//			lastPhi = phi;
	//			phi = (r == currNumConstraints - 1) ? 0.0 : (phi - r * (Ds.get(r).getSecond() - Ds.get(r + 1).getSecond()));
	//		}
	//		double theta = Ds.get(r).getSecond() - lastPhi / r;
	//		double[] newAlphas = new double[currNumConstraints];
	//		for (int q = 0; q < currNumConstraints; ++q) {
	//			final int yi = Ds.get(q).getFirst();
	//			final double v = q < r ? theta : Ds.get(q).getSecond();
	//			final double alpha = -v;
	//			newAlphas[yi] = alpha;
	//		}
	//
	//		IntCounter weightsDelta = setAlphas(i, oldAlphas, newAlphas);
	//		incrementAll(weights, weightsDelta, 1.0);
	//
	//		return true;
	//
	//	}

	/**
	 * Builds a list from {@code a.enumerate(start, end)} and returns it in random order.
	 *
	 * @param start
	 *            first argument handed to {@code a.enumerate}
	 * @param end
	 *            second argument handed to {@code a.enumerate} (presumably an exclusive upper bound -- TODO confirm against
	 *            {@code a.enumerate})
	 * @return the enumerated values as a list, after shuffling
	 */
	private List<Integer> shuffle(int start, int end) {
		List<Integer> result = a.toList(a.enumerate(start, end));
		// Single-argument shuffle: uses Collections' default source of randomness, not this learner's rand field.
		Collections.shuffle(result);
		return result;
	}

	/**
	 * Copies a dense weight vector into a newly created {@link IntCounter}, storing every position (including zeros) under
	 * its array index.
	 *
	 * @param weights
	 *            dense weight vector to copy
	 * @return a new counter holding {@code weights[k]} under key {@code k}
	 */
	public static IntCounter toCounter(double[] weights) {
		final IntCounter dense = new IntCounter(weights.length);
		int position = 0;
		for (double value : weights) {
			dense.setCount(position, value);
			position++;
		}
		return dense;
	}

	/**
	 * Returns the result of {@code IntCounter.wrapArray(weights, size)}.
	 *
	 * @param weights
	 *            dense weight array handed to {@code wrapArray}
	 * @param size
	 *            size argument handed to {@code wrapArray}
	 * @return the counter produced by {@code IntCounter.wrapArray}
	 */
	public static IntCounter wrapCounter(double[] weights, int size) {
		return IntCounter.wrapArray(weights, size);
	}

	/**
	 * Decides whether the optimizer has converged.
	 * <p>
	 * When {@code opts.smoCheckPrimal} is set, convergence means the primal/dual gap is below {@code opts.precision} in
	 * absolute terms, or below {@code tol} relative to the average magnitude of the two objectives. Otherwise only the
	 * dual is consulted: convergence means the improvement over {@code lastDual} is below {@code opts.precision}, or
	 * below {@code tol} relative to the current dual value.
	 *
	 * @param primal
	 *            current primal objective (ignored when {@code opts.smoCheckPrimal} is false)
	 * @param dual
	 *            current dual objective
	 * @param tol
	 *            relative tolerance, applied in both modes
	 * @param lastDual
	 *            dual objective from the previous check (only used when {@code opts.smoCheckPrimal} is false)
	 * @return {@code true} if the stopping criterion is met
	 */
	boolean converged(double primal, double dual, double tol, double lastDual) {
		if (opts.smoCheckPrimal) {
			final double gap = Math.abs(primal - dual);
			final double valueAverage = (Math.abs(dual) + Math.abs(primal)) / 2.0;
			return gap < opts.precision || gap / valueAverage < tol;
		}
		final double diff = dual - lastDual;
		// Honour the caller-supplied tolerance here as well; this branch used to ignore it in favour of opts.smoTol.
		return diff < opts.precision || diff / dual < tol;
	}

	/**
	 * Performs one pairwise update on the alphas of constraints {@code yi} and {@code yj} of instance {@code i}, using
	 * the precomputed {@code dotProdCache}. A step is added to alpha {@code yi} and subtracted from alpha {@code yj};
	 * the step is clipped to {@code [-alpha_yi, alpha_yj]} so neither alpha becomes negative. This method only touches
	 * {@code indexToAlpha}; it does not modify the weight vector.
	 *
	 * @param i
	 *            instance (constraint set) index
	 * @param yi
	 *            relative index of the constraint that receives the step
	 * @param yj
	 *            relative index of the constraint that gives up the step
	 */
	public void updateAlphas(int i, int yi, int yj) {
		final int absI = getAlphaAbsoluteIndexFromRelativeIndices(i, yi);
		final int absJ = getAlphaAbsoluteIndexFromRelativeIndices(i, yj);
		final double[] rowI = dotProdCache[absI];
		final double[] rowJ = dotProdCache[absJ];
		// Both deltas have zero squared norm: nothing to optimize for this pair.
		if (rowI[absI] == 0 && rowJ[absJ] == 0) return;

		final double lossGap = indexToLoss[i].get(yi) - indexToLoss[i].get(yj);
		final double scale = C / indexToAlpha.length;

		// Accumulate the alpha-weighted dot products of every constraint with delta_yi (negated) and delta_yj.
		double pullI = 0.0;
		double pullJ = 0.0;
		for (int k = 0; k < indexToAlpha.length; ++k) {
			final int count = numConstraints(k);
			for (int yk = 0; yk < count; ++yk) {
				final int absK = getAlphaAbsoluteIndexFromRelativeIndices(k, yk);
				pullI -= scale * indexToAlpha[k].get(yk) * rowI[absK];
				pullJ += scale * indexToAlpha[k].get(yk) * rowJ[absK];
			}
		}

		final double numerator = lossGap + (pullI + pullJ);
		final double denominator = scale * rowI[absI] - 2.0 * scale * rowI[absJ] + scale * rowJ[absJ];
		if (denominator == 0) return;

		final List<Double> alphas = indexToAlpha[i];
		final double step = Math.max(-alphas.get(yi), Math.min(alphas.get(yj), numerator / denominator));
		alphas.set(yi, alphas.get(yi) + step);
		alphas.set(yj, alphas.get(yj) - step);
	}

	/**
	 * Performs one pairwise step on constraints {@code yi} and {@code yj} of instance {@code i} while keeping the weight
	 * vector and the per-constraint dot-product cache in sync. The step size comes from {@link #computeSmoStep}; when it
	 * is non-zero the two alphas are updated, the weights are moved via {@code updateWeightsPair}, and every entry of
	 * {@code deltaWeightsDotProdCache} is adjusted in place.
	 *
	 * @param i
	 *            instance (constraint set) index
	 * @param yi
	 *            relative index of the constraint whose alpha grows by the step
	 * @param yj
	 *            relative index of the constraint whose alpha shrinks by the step
	 * @param deltaWeightsDotProdCache
	 *            per-constraint cache, updated in place (presumably delta . weights for each constraint of instance
	 *            {@code i} -- confirm against caller)
	 * @param deltaDeltaNormCache
	 *            read at {@code [yi][yj]} (presumably the squared norm of delta_yi - delta_yj -- confirm against caller)
	 * @param deltaDeltaDotProductCache
	 *            pairwise dot products, read through {@code getDotProd}
	 * @return {@code true} if the alphas (and weights) changed, {@code false} if the pair was skipped
	 */
	public boolean updateAlphasPrimal(int i, int yi, int yj, double[] deltaWeightsDotProdCache, double[][] deltaDeltaNormCache,
		double[][] deltaDeltaDotProductCache) {

		final List<Double> alphas = indexToAlpha[i];
		final double alphaI = alphas.get(yi);
		final double alphaJ = alphas.get(yj);
		// With both alphas at zero the pair is skipped.
		if (alphaI == 0.0 && alphaJ == 0.0) return false;

		final double scale = C / indexToAlpha.length;
		final double curvature = scale * deltaDeltaNormCache[yi][yj];
		if (curvature == 0.0) return false;

		final IntCounter deltaI = indexToDelta[i].get(yi);
		final IntCounter deltaJ = indexToDelta[i].get(yj);
		final double margin = deltaWeightsDotProdCache[yi] - deltaWeightsDotProdCache[yj];
		final double step = computeSmoStep(i, yi, yj, curvature, alphaI, alphaJ, margin);
		if (step == 0) return false;

		final double grownAlpha = alphaI + step;
		final double shrunkAlpha = alphaJ - step;
		assert grownAlpha >= -1e-7 && grownAlpha < 1 + 1e-7;
		assert shrunkAlpha >= -1e-7 && shrunkAlpha < 1 + 1e-7;
		alphas.set(yi, grownAlpha);
		alphas.set(yj, shrunkAlpha);

		final double weightStep = step * scale;
		updateWeightsPair(deltaI, deltaJ, weightStep);

		// Keep the cached dot products consistent with the weight change just applied.
		final int cached = deltaWeightsDotProdCache.length;
		for (int yk = 0; yk < cached; ++yk) {
			deltaWeightsDotProdCache[yk] += weightStep * getDotProd(i, yi, yk, deltaDeltaDotProductCache);
			deltaWeightsDotProdCache[yk] -= weightStep * getDotProd(i, yj, yk, deltaDeltaDotProductCache);
		}
		return true;
	}

	/**
	 * Looks up the dot product between the deltas of constraints {@code yj} and {@code yk} of instance {@code i}. The
	 * diagonal is served from {@code indexToDeltaNormSquared}; off-diagonal entries are read from the cache with the
	 * smaller index first, so only the upper triangle of the cache is consulted.
	 */
	private double getDotProd(int i, int yj, int yk, double[][] deltaDeltaDotProductCache) {
		if (yj == yk) {
			return indexToDeltaNormSquared[i].get(yj);
		}
		final int row = Math.min(yj, yk);
		final int col = Math.max(yj, yk);
		return deltaDeltaDotProductCache[row][col];
	}

	/**
	 * Applies a paired weight update: {@code delta_yi} scaled by {@code t}, followed by {@code delta_yj} scaled by
	 * {@code -t}. Both go through {@code updateWeights}.
	 *
	 * @param delta_yi
	 *            feature delta applied with factor {@code t}
	 * @param delta_yj
	 *            feature delta applied with factor {@code -t}
	 * @param t
	 *            scale factor
	 */
	private void updateWeightsPair(final IntCounter delta_yi, final IntCounter delta_yj, double t) {
		final double negated = -t;
		updateWeights(delta_yi, t);
		updateWeights(delta_yj, negated);
	}

	/**
	 * Applies the same scaled update to both places the weights are tracked: this learner's dense {@code weights} array
	 * (through {@code incrementAll}) and the model (through {@code model.updateWeights}).
	 *
	 * @param delta_yi
	 *            feature delta to apply
	 * @param t
	 *            scale factor passed along with the delta
	 */
	private void updateWeights(final IntCounter delta_yi, double t) {
		incrementAll(weights, delta_yi, t);
		model.updateWeights(delta_yi, t);
	}

	/**
	 * Computes the clipped step for a pair of constraints. The unclipped step is
	 * {@code (loss_yi - loss_yj - deltaWeightsDotProd) / b}; it is then clipped to {@code [-alpha_yi, alpha_yj]}.
	 *
	 * @param i
	 *            instance (constraint set) index
	 * @param yi
	 *            relative index of the first constraint
	 * @param yj
	 *            relative index of the second constraint
	 * @param b
	 *            divisor of the unclipped step (callers are expected to pass a non-zero value)
	 * @param alpha_yi
	 *            current alpha of constraint {@code yi}; bounds the step from below at {@code -alpha_yi}
	 * @param alpha_yj
	 *            current alpha of constraint {@code yj}; bounds the step from above
	 * @param deltaWeightsDotProd
	 *            value subtracted from the loss difference in the numerator
	 * @return the clipped step
	 */
	private double computeSmoStep(int i, int yi, int yj, double b, final double alpha_yi, final double alpha_yj, final double deltaWeightsDotProd) {
		final double numerator = -1.0 * deltaWeightsDotProd + indexToLoss[i].get(yi) - indexToLoss[i].get(yj);
		final double lowerBound = -1.0 * alpha_yi;
		final double upperBound = alpha_yj;
		final double cappedAbove = Math.min(upperBound, numerator / b);
		return Math.max(lowerBound, cappedAbove);
	}

	/**
	 * Thin wrapper around {@code IntCounter.incrementDenseArray(weights, currDelta, d)}; presumably adds {@code d} times
	 * each entry of {@code currDelta} into {@code weights} -- confirm against {@code IntCounter}.
	 */
	private void incrementAll(final double[] weights, IntCounter currDelta, double d) {
		IntCounter.incrementDenseArray(weights, currDelta, d);
	}

	/**
	 * Discards all stored constraints and re-creates the five parallel per-instance lists (deltas, squared delta norms,
	 * alphas, losses and inactivity counters), each starting out empty.
	 *
	 * @param numConstraintSets
	 *            number of instances, i.e. the length of each per-instance array
	 */
	void clearConstraints(int numConstraintSets) {
		final int sets = numConstraintSets;
		indexToDelta = new List[sets];
		indexToDeltaNormSquared = new List[sets];
		indexToAlpha = new List[sets];
		indexToLoss = new List[sets];
		activeAlphasPriorityQueue = new List[sets];
		for (int slot = sets - 1; slot >= 0; --slot) {
			indexToDelta[slot] = new ArrayList<IntCounter>();
			indexToDeltaNormSquared[slot] = new ArrayList<Double>();
			indexToAlpha[slot] = new ArrayList<Double>();
			indexToLoss[slot] = new ArrayList<Double>();
			activeAlphasPriorityQueue[slot] = new ArrayList<Integer>();
		}
	}

	/**
	 * Ages and prunes constraints whose alpha is (numerically) zero. Each constraint carries an inactivity counter in
	 * {@code activeAlphasPriorityQueue}: it is reset to 0 whenever the alpha is at least {@code 1e-30}, incremented
	 * otherwise, and once it exceeds {@code opts.maxInactiveAlphaCount} the constraint is deleted. Pruning is disabled
	 * when {@code opts.maxInactiveAlphaCount} is {@code Integer.MAX_VALUE}.
	 */
	private void pruneInactiveAlphas() {
		if (opts.maxInactiveAlphaCount == Integer.MAX_VALUE) return;
		for (int i = 0; i < indexToAlpha.length; ++i) {
			int yi = 0;
			while (yi < numConstraints(i)) {
				if (indexToAlpha[i].get(yi) < 1e-30) {
					int curr = activeAlphasPriorityQueue[i].get(yi);
					if (curr > opts.maxInactiveAlphaCount) {
						deleteConstraint(i, yi);
						// The following constraint has shifted into slot yi; re-examine this slot instead of
						// advancing, otherwise that constraint would be skipped for this round.
						continue;
					}
					activeAlphasPriorityQueue[i].set(yi, curr + 1);
				} else {
					activeAlphasPriorityQueue[i].set(yi, 0);
				}
				++yi;
			}
		}
	}

	/**
	 * Removes constraint {@code yi} of instance {@code i} from every parallel per-constraint list (alpha, squared norm,
	 * delta, loss and inactivity counter) and takes its contribution out of the weights.
	 *
	 * @param i
	 *            instance (constraint set) index
	 * @param yi
	 *            relative index of the constraint to remove
	 */
	private void deleteConstraint(int i, int yi) {
		double alpha = indexToAlpha[i].remove(yi);
		indexToDeltaNormSquared[i].remove(yi);
		IntCounter delta_yi = indexToDelta[i].remove(yi);
		indexToLoss[i].remove(yi);
		// Keep the inactivity counters aligned with the other lists; yi is an int, so this is remove-by-index.
		activeAlphasPriorityQueue[i].remove(yi);
		// NOTE(review): the removal is scaled by -alpha only, whereas updateAlphasPrimal scales weight steps by
		// C / indexToAlpha.length. Harmless for the near-zero alphas pruned above, but confirm if reused elsewhere.
		updateWeights(delta_yi, -alpha);
	}

	/**
	 * Runs loss-augmented decoding over a batch by delegating to {@code model.getLossAugmentedUpdateBundleBatch} and
	 * adds the batch size to {@code numDecodes}. The {@code weights} argument is not used by this method; the call that
	 * once pushed it into the model is disabled.
	 *
	 * @param model
	 *            model that performs the decoding
	 * @param data
	 *            batch of instances to decode
	 * @param weights
	 *            unused here
	 * @param lossWeight
	 *            loss weight forwarded to the model
	 * @return the update bundles returned by the model
	 */
	public List<UpdateBundle> batchLossAugmentedDecode(LossAugmentedLinearModel<T> model, List<T> data, CounterInterface<Integer> weights, double lossWeight) {
		final List<UpdateBundle> bundles = model.getLossAugmentedUpdateBundleBatch(data, lossWeight);
		final int decoded = data.size();
		numDecodes += decoded;
		return bundles;
	}

	/**
	 * Decodes one mini-batch with loss augmentation and adds a constraint for each instance whose new constraint is
	 * sufficiently more violated than the ones already stored (see {@link #addConstraintIfNecessary}).
	 * <p>
	 * When {@code initial} is set, all constraint storage is first reset for {@code totalDataSize} instances and every
	 * instance is seeded with an empty-delta, zero-loss constraint. Bundles with infinite loss are reported on stdout and
	 * skipped.
	 *
	 * @param initial
	 *            whether to reset and seed the constraint storage first
	 * @param model
	 *            model used for decoding
	 * @param data
	 *            the mini-batch; element {@code k} corresponds to instance {@code currStart + k}
	 * @param weights
	 *            weights used when measuring constraint slack
	 * @param currStart
	 *            instance index of the first element of {@code data}
	 * @param numMiniBatches
	 *            not used by this method
	 * @param currMiniBatchIndex
	 *            not used by this method
	 * @param totalDataSize
	 *            total number of instances; only used when {@code initial} is set
	 * @return the number of constraints added
	 */
	public int reapConstraints(boolean initial, LossAugmentedLinearModel<T> model, List<T> data, CounterInterface<Integer> weights, int currStart,
		int numMiniBatches, int currMiniBatchIndex, int totalDataSize) {
		if (initial) {
			clearConstraints(totalDataSize);
			int slot = 0;
			while (slot < totalDataSize) {
				addConstraint(slot, new IntCounter(), 0.0);
				++slot;
			}
		}

		final List<UpdateBundle> bundles = batchLossAugmentedDecode(model, data, weights, 1.0);

		int added = 0;
		for (int offset = 0; offset < data.size(); ++offset) {
			final UpdateBundle bundle = bundles.get(offset);
			if (bundle.loss == Double.POSITIVE_INFINITY) {
				System.out.println("Hmmm, infinite loss, ignoring");
				continue;
			}
			// Feature difference between the gold structure and the loss-augmented guess.
			final IntCounter featureDiff = new IntCounter();
			featureDiff.incrementAll(bundle.gold);
			featureDiff.incrementAll(bundle.guess, -1.0);
			final double loss = bundle.loss;

			added += addConstraintIfNecessary(weights, currStart + offset, featureDiff, loss);
		}
		return added;
	}

	/**
	 * Adds the candidate constraint to instance {@code i} only if its slack exceeds the largest slack among the
	 * instance's existing constraints by more than {@code epsilon}. If loss-augmented decoding checks are enabled and the
	 * new slack falls noticeably below the current maximum, a warning is printed (decoding should not produce a less
	 * violated constraint than one already stored).
	 *
	 * @param weights
	 *            weights used to evaluate the slacks
	 * @param i
	 *            instance (constraint set) index
	 * @param delta
	 *            feature delta of the candidate constraint
	 * @param loss
	 *            loss of the candidate constraint
	 * @return 1 if the constraint was added, 0 otherwise
	 */
	protected int addConstraintIfNecessary(CounterInterface<Integer> weights, int i, IntCounter delta, double loss) {
		double currentSlack = Double.NEGATIVE_INFINITY;
		final int existing = numConstraints(i);
		for (int yi = 0; yi < existing; ++yi) {
			final double slack = getContraintSlack(i, yi, weights);
			currentSlack = Math.max(currentSlack, slack);
		}

		final double newSlack = getContraintSlack(weights, delta, loss);
		if (shouldCheckLossAugmentedDecoding() && newSlack < currentSlack - opts.lossAugCheckTol * currentSlack) {
			// Deliberately only reported, not thrown.
			System.out.println("Something wrong with loss augmented decoding, new slack is " + newSlack + " and current slack is " + currentSlack);
		}

		final boolean violatedEnough = newSlack > currentSlack + epsilon;
		if (!violatedEnough) {
			return 0;
		}
		addConstraint(i, delta, loss);
		return 1;
	}

	/**
	 * Whether {@code addConstraintIfNecessary} should print a warning when a newly decoded constraint has noticeably less
	 * slack than the existing ones. Always {@code true} here; protected so subclasses can override it.
	 */
	protected boolean shouldCheckLossAugmentedDecoding() {
		return true;
	}

	/**
	 * Rebuilds {@code dotProdCache}, the full matrix of pairwise dot products between all stored constraint deltas,
	 * indexed by absolute constraint index. Does nothing when {@code opts.primalSmo} is set.
	 */
	public void buildDotProdCache() {
		if (opts.primalSmo) {
			return;
		}

		System.out.println("Building dotprod cache");
		final int total = numConstraints();
		this.dotProdCache = new double[total][total];
		for (int row = 0; row < total; ++row) {
			final Pair<Integer, Integer> rowIndices = getAlphaRelativeIndicesFromAbsoluteIndex(row);
			final CounterInterface<Integer> rowDelta = indexToDelta[rowIndices.getFirst()].get(rowIndices.getSecond());
			for (int col = 0; col < total; ++col) {
				final Pair<Integer, Integer> colIndices = getAlphaRelativeIndicesFromAbsoluteIndex(col);
				final IntCounter colDelta = indexToDelta[colIndices.getFirst()].get(colIndices.getSecond());
				dotProdCache[row][col] = rowDelta.dotProduct(colDelta);
			}
		}
	}

	/**
	 * Resets the alphas of every instance so that the first constraint carries all the mass: alpha 0 becomes 1.0 and
	 * every other alpha becomes 0.0. (Despite the name, not every alpha ends up at zero.)
	 */
	public void zeroInitializeAlphas() {
		for (int i = 0; i < indexToAlpha.length; ++i) {
			final int count = numConstraints(i);
			for (int yi = 0; yi < count; ++yi) {
				final double value = (yi == 0) ? 1.0 : 0.0;
				indexToAlpha[i].set(yi, value);
			}
		}
	}

	/**
	 * Resets the alphas of every instance to 0.9 for the first constraint and an equal share of the remaining 0.1 for
	 * each of the others. An instance with a single constraint ends up with just 0.9.
	 */
	public void uniformInitializeAlphas() {
		for (int i = 0; i < indexToAlpha.length; ++i) {
			final int count = numConstraints(i);
			for (int yi = 0; yi < count; ++yi) {
				final double share = (yi == 0) ? 0.9 : 0.1 / (count - 1.0);
				indexToAlpha[i].set(yi, share);
			}
		}
	}

	/**
	 * Appends a new constraint to instance {@code i}.
	 * <p>
	 * Steps, in order: a zero alpha and a zero inactivity counter are appended and the instance's alphas are
	 * re-normalized; the dense {@code weights} array is grown if the delta mentions a feature index beyond its current
	 * length; entries of {@code delta} whose magnitude is below {@code opts.minCountThresh} times the largest entry
	 * magnitude (capped at 1.0) are zeroed; a sorted copy of the delta is stored together with its squared norm and the
	 * loss; finally the new constraint's alpha is set to {@code opts.newAlphaMag} via {@code setAlphas} and the
	 * resulting weight change is applied to {@code weights}.
	 * <p>
	 * Note that the pruning step writes zeros into the caller's {@code delta}.
	 *
	 * @param i
	 *            instance (constraint set) index
	 * @param delta
	 *            feature delta of the constraint; low-magnitude entries are zeroed in place
	 * @param loss
	 *            loss associated with the constraint
	 */
	public void addConstraint(int i, IntCounter delta, double loss) {
		indexToAlpha[i].add(0.0);
		activeAlphasPriorityQueue[i].add(0);
		normalizeAlphas(i);

		double maxFeatureCount = 0.0;
		for (Map.Entry<Integer, Double> entry : delta.entries()) {
			maxFeatureIndexSoFar = Math.max(entry.getKey(), maxFeatureIndexSoFar);
			// Track the largest magnitude. Taking abs() of the running max (as before) ignored negative entries,
			// so a delta dominated by negative counts was never pruned.
			maxFeatureCount = Math.max(maxFeatureCount, Math.abs(entry.getValue()));
		}
		maxFeatureCount = Math.min(1.0, maxFeatureCount);
		if (maxFeatureIndexSoFar >= weights.length) {
			// Grow by at least 1.5x so repeated additions do not reallocate every time.
			weights = Arrays.copyOf(weights, Math.max(maxFeatureIndexSoFar + 1, weights.length * 3 / 2));
		}

		// prune low count features. Really helps for small expected counts
		final double pruneBelow = maxFeatureCount * opts.minCountThresh;
		List<Integer> toClear = new ArrayList<Integer>();
		for (Entry<Integer, Double> entry : delta.entries()) {
			if (Math.abs(entry.getValue()) < pruneBelow) {
				toClear.add(entry.getKey());
			}
		}
		// Collected first, cleared afterwards, so the counter is not modified while its entries are iterated.
		for (int key : toClear) {
			delta.setCount(key, 0.0);
		}

		IntCounter fresh = new IntCounter();
		fresh.incrementAll(delta);
		fresh.toSorted();
		indexToDelta[i].add(fresh);
		indexToDeltaNormSquared[i].add(fresh.normSquared());
		indexToLoss[i].add(loss);

		double[] alphas = getAlphas(i);
		double[] newAlphas = Arrays.copyOf(alphas, alphas.length);
		newAlphas[newAlphas.length - 1] = opts.newAlphaMag;

		IntCounter weightsDelta = setAlphas(i, alphas, newAlphas, new IntCounter());
		incrementAll(weights, weightsDelta, 1.0);
	}

	/**
	 * Rescales the alphas of instance {@code i} so that they sum to one. Leaves them untouched when their sum is not
	 * strictly positive.
	 *
	 * @param i
	 *            instance (constraint set) index
	 */
	void normalizeAlphas(int i) {
		final int count = numConstraints(i);
		double total = 0.0;
		for (int yi = 0; yi < count; ++yi) {
			total += indexToAlpha[i].get(yi);
		}
		if (!(total > 0)) {
			return;
		}
		for (int yi = 0; yi < count; ++yi) {
			final double scaled = indexToAlpha[i].get(yi) / total;
			indexToAlpha[i].set(yi, scaled);
		}
	}

	/**
	 * @param i
	 *            instance (constraint set) index
	 * @return the number of constraints currently stored for instance {@code i}, or 0 if no constraint storage has been
	 *         allocated yet
	 */
	public int numConstraints(int i) {
		return (indexToAlpha == null) ? 0 : indexToAlpha[i].size();
	}

	/**
	 * @return the total number of constraints stored across all instances, or 0 if no constraint storage has been
	 *         allocated yet
	 */
	public int numConstraints() {
		if (indexToAlpha == null) {
			return 0;
		}
		int total = 0;
		int i = 0;
		while (i < indexToAlpha.length) {
			total += numConstraints(i);
			i++;
		}
		return total;
	}

	/**
	 * Returns the index of {@code thing}, assigning the next free index if it has not been seen before. A new item is
	 * appended to {@code indexToThing} and recorded in {@code thingToIndex}, so the two stay mutually consistent.
	 *
	 * @param thing
	 *            item to look up or register
	 * @param indexToThing
	 *            index-to-item list; appended to when {@code thing} is new
	 * @param thingToIndex
	 *            item-to-index map; updated when {@code thing} is new
	 * @return the (possibly newly assigned) index of {@code thing}
	 */
	static <D> int index(D thing, List<D> indexToThing, Map<D, Integer> thingToIndex) {
		final Integer known = thingToIndex.get(thing);
		if (known != null) {
			return known;
		}
		final int assigned = indexToThing.size();
		thingToIndex.put(thing, assigned);
		indexToThing.add(thing);
		return assigned;
	}

	/**
	 * Projects {@code v} onto the probability simplex (non-negative entries summing to one) using the sort-and-threshold
	 * method: the values are sorted in decreasing order, the last position {@code p} whose value stays positive after
	 * subtracting the running threshold is located, and every entry is then shifted by the threshold computed at
	 * {@code p} and clipped at zero.
	 * <p>
	 * The projection is written back into {@code v} in place and also returned as an array. An empty list yields an
	 * empty array.
	 *
	 * @param v
	 *            vector to project; overwritten with the projected values
	 * @return the projected values, in the same order as {@code v}
	 */
	public static double[] projectToSimplex(List<Double> v) {
		final int n = v.size();
		if (n == 0) {
			// Nothing to project; previously this ran off the end of the (empty) sorted copy.
			return new double[0];
		}

		List<Double> u = new ArrayList<Double>(v);
		Collections.sort(u);
		Collections.reverse(u);

		// prefix[j] = u[0] + ... + u[j], accumulated once instead of being recomputed from scratch for every j.
		final double[] prefix = new double[n];
		double running = 0.0;
		int p = 0;
		for (int j = 0; j < n; ++j) {
			running += u.get(j);
			prefix[j] = running;
			if (u.get(j) - (1.0 / (j + 1.0)) * (running - 1.0) > 0) {
				p = j;
			}
		}
		final double theta = (1.0 / (p + 1.0)) * (prefix[p] - 1.0);

		final double[] ret = new double[n];
		double sumv = 0.0;
		for (int i = 0; i < n; ++i) {
			final double projected = Math.max(v.get(i) - theta, 0);
			v.set(i, projected);
			ret[i] = projected;
			sumv += projected;
		}
		assert sumv > 1.0 - 1e-6 && sumv < 1.0 + 1e-6;
		return ret;
	}

}