package edu.neu.ccs.pyramid.multilabel_classification.crf;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.SequentialSparseDataSet;
import edu.neu.ccs.pyramid.optimization.Terminator;
import edu.neu.ccs.pyramid.util.MathUtil;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;

import java.util.*;
import java.util.stream.IntStream;

/**
 * Trains a Conditional Multi-Label CRF ({@link CMLCRF}) with an ElasticNet
 * (l1Ratio * L1 + (1-l1Ratio) * L2) penalty by coordinate descent: each supported
 * label combination is handled via a weighted linear-regression subproblem
 * (quadratic approximation), followed by a backtracking line search on the
 * combined direction.
 *
 * <p>Created by Rainicy on 11/10/16.
 */
public class CMLCRFElasticNet {
    private Terminator terminator;
    private CMLCRF cmlcrf;
    private List<MultiLabel> supportedCombinations;
    private int numSupport;
    private MultiLabelClfDataSet dataSet;
    private int numClasses;
    private int numParameters;
    private int numWeightsForFeatures;
    private int numWeightsForLabelPairs;
    // cached objective value (negative log-likelihood + penalty)
    private double value;
    private Vector empiricalCounts;
    private Vector predictedCounts;
    // for a label-pair parameter position: the two class indices of the pair
    private int[] parameterToL1;
    private int[] parameterToL2;
    // for a feature parameter position: its class index and feature index (-1 = bias)
    private int[] parameterToClass;
    private int[] parameterToFeature;
    // whether the support combination contains the label;
    // size num combinations * num classes
    private boolean[][] comContainsLabel;
    private boolean isParallel = true;
    // guards the cached objective in "value"; must be reset whenever the weights change
    private boolean isValueCacheValid = false;
    // numDataPoints by numClasses
    private double[][] classScoreMatrix;
    // numDataPoints by numClasses
    private double[][] classProbMatrix;
    // numDataPoints by numCombinations
    private double[][] combProbMatrix;
    // numDataPoints by numCombinations
    private double[][] combScoreMatrix;
    private int numData;
    // for each data point, the position of its true combination in the support list
    private int[] labelComIndices;
    private double l1Ratio;
    private double regularization;
    private int numFeature;
    // for each combination (index): the list of its non-zero label-pair feature positions
    // (positions start from 0)
    private List<List<Integer>> combinationToLabelPair;
    // for each label pair (index): the list of matched combinations (index);
    // number of pairs * variable length
    private List<List<Integer>> labelPairToCombination;
    // for each combination: the sum of its probabilities over all data points;
    // size = num combinations
    private double[] combProbSums;

    /**
     * @param cmlcrf         model to train; its weight vector is updated in place
     * @param dataSet        training data
     * @param l1Ratio        ElasticNet mixing parameter in [0,1]; 1 = pure L1, 0 = pure L2
     * @param regularization overall penalty strength
     */
    public CMLCRFElasticNet(CMLCRF cmlcrf, MultiLabelClfDataSet dataSet, double l1Ratio, double regularization) {
        this.l1Ratio = l1Ratio;
        this.regularization = regularization;
        this.terminator = new Terminator();
        this.terminator.setGoal(Terminator.Goal.MINIMIZE);
        this.numFeature = dataSet.getNumFeatures();
        this.cmlcrf = cmlcrf;
        this.supportedCombinations = cmlcrf.getSupportCombinations();
        this.numSupport = cmlcrf.getNumSupports();
        this.dataSet = dataSet;
        this.numData = dataSet.getNumDataPoints();
        this.numClasses = dataSet.getNumClasses();
        this.numParameters = cmlcrf.getWeights().totalSize();
        this.numWeightsForFeatures = cmlcrf.getWeights().getNumWeightsForFeatures();
        this.numWeightsForLabelPairs = cmlcrf.getWeights().getNumWeightsForLabels();
        this.classScoreMatrix = new double[numData][numClasses];
        this.classProbMatrix = new double[numData][numClasses];
        this.combScoreMatrix = new double[numData][numSupport];
        this.combProbMatrix = new double[numData][numSupport];
        this.isValueCacheValid = false;
        this.empiricalCounts = new DenseVector(numParameters);
        this.predictedCounts = new DenseVector(numParameters);
        // initCache fills the parameter->(class,feature,pair) lookup tables used below
        this.initCache();
        // empirical counts depend only on the data, so compute them once up front
        this.updateEmpiricalCounts();
        this.combinationToLabelPair = new ArrayList<>(numSupport);
        for (int i = 0; i < numSupport; i++) {
            combinationToLabelPair.add(new LinkedList<>());
        }
        this.labelPairToCombination = new ArrayList<>(numWeightsForLabelPairs);
        for (int i = 0; i < numWeightsForLabelPairs; i++) {
            labelPairToCombination.add(new ArrayList<>());
        }
        this.mapCombinationToPair();
        this.mapPairToCombination();
        this.combProbSums = new double[numSupport];
        // index the support list so each data point can find its true combination in O(1)
        Map<MultiLabel, Integer> map = new HashMap<>();
        for (int s = 0; s < numSupport; s++) {
            map.put(supportedCombinations.get(s), s);
        }
        this.labelComIndices = new int[dataSet.getNumDataPoints()];
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            labelComIndices[i] = map.get(dataSet.getMultiLabels()[i]);
        }
    }

    /**
     * Runs {@link #iterate()} until the terminator's stopping criterion is met.
     */
    public void optimize() {
        while (true) {
            iterate();
            if (terminator.shouldTerminate()) {
                break;
            }
        }
    }

    /**
     * One outer iteration: refresh all score/probability caches, solve one
     * penalized regression subproblem per supported combination (each starting
     * from the same old weights), accumulate the solutions into a search
     * direction, and take a line-searched step along it.
     */
    public void iterate() {
        // O(NdL)
        updateClassScoreMatrix();
        updateCombLabelPartScoresAndCaches();
        // solve a subproblem for each supported label combination
        Vector accumulateWeights = new SequentialAccessSparseVector(numParameters);
        Vector oldWeights = cmlcrf.getWeights().deepCopy().getAllWeights();
        for (int l = 0; l < numSupport; l++) {
            DataSet newData = expandData(l);
            iterateForOneComb(newData, l);
            accumulateWeights = accumulateWeights.plus(cmlcrf.getWeights().getAllWeights());
            // restore the starting point so every subproblem sees the same weights
            cmlcrf.getWeights().setWeightVector(oldWeights);
        }
        Vector searchDirection = accumulateWeights;
        Vector gradient = this.predictedCounts.minus(empiricalCounts).divide(numData);
        lineSearch(searchDirection, gradient);
        this.terminator.add(getValue());
    }

    /**
     * Refreshes the label-part scores and every per-data-point cache that
     * depends on the current weights.
     */
    private void updateCombLabelPartScoresAndCaches() {
        cmlcrf.updateCombLabelPartScores();
        updateAssignmentScoreMatrix();
        updateAssignmentProbMatrix();
        updateCombProbSums();
        updatePredictedCounts();
        updateClassProbMatrix();
    }

    /**
     * Builds the expanded design matrix for combination {@code l}: one row per
     * data point, with columns for every (class, feature) weight that is active
     * under l's labels (bias included) plus indicator columns for l's label pairs.
     *
     * @param l index of the supported combination
     * @return expanded data set of size numData x numParameters
     */
    private DataSet expandData(int l) {
        SequentialSparseDataSet newData = new SequentialSparseDataSet(numData, numParameters, false);
        MultiLabel label = supportedCombinations.get(l);
        List<Integer> labelPairForL = combinationToLabelPair.get(l);
        // TODO: parallelism
        for (int i = 0; i < numData; i++) {
            // feature-label features: for each label in the combination, copy the
            // instance's feature values into that label's weight slots
            for (int y : label.getMatchedLabels()) {
                // set bias as 1
                newData.setFeatureValue(i, (numFeature + 1) * y, 1.0);
                for (Vector.Element element : dataSet.getRow(i).nonZeroes()) {
                    int index = element.index();
                    double value = element.get();
                    newData.setFeatureValue(i, (numFeature + 1) * y + index + 1, value);
                }
            }
            // label-label features: indicator columns for this combination's pairs
            for (int y : labelPairForL) {
                newData.setFeatureValue(i, (numWeightsForFeatures + y), 1.0);
            }
        }
        return newData;
    }

    /**
     * Solves the weighted ElasticNet linear-regression subproblem for
     * combination {@code l}, writing the solution into the model's weights.
     * Targets and instance weights come from the quadratic approximation of the
     * log-likelihood at the current probabilities.
     */
    private void iterateForOneComb(DataSet newData, int l) {
        double[] realLabels = new double[numData];
        double[] instanceWeights = new double[numData];
        IntStream.range(0, numData).parallel().forEach(i -> {
            double prob = combProbMatrix[i][l];
            double classScore = combScoreMatrix[i][l];
            int y = labelComIndices[i];
            double frac = 0;
            double tmpP = prob * (1 - prob);
            int indicator = (y == l) ? 1 : 0;
            if (prob != 0 && prob != 1) {
                frac = (indicator - prob) / tmpP;
            }
            // clip the working response to keep the subproblem well conditioned
            if (frac > 1) {
                frac = 1;
            }
            if (frac < -1) {
                frac = -1;
            }
            realLabels[i] = classScore + frac;
            instanceWeights[i] = tmpP;
        });
        CRFLinearRegression linearRegression =
                new CRFLinearRegression(numParameters, cmlcrf.getWeights().getAllWeights());
        CRFElasticNetLinearRegOptimizer linearRegTrainer =
                new CRFElasticNetLinearRegOptimizer(linearRegression, newData, realLabels, instanceWeights);
        linearRegTrainer.setRegularization(regularization);
        linearRegTrainer.setL1Ratio(l1Ratio);
        linearRegTrainer.optimize();
        // weights changed, so the cached objective is no longer valid
        isValueCacheValid = false;
    }

    private void updateClassScoreMatrix() {
        IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .forEach(i -> classScoreMatrix[i] = cmlcrf.predictClassScores(dataSet.getRow(i)));
    }

    private void updateAssignmentScoreMatrix() {
        IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .forEach(i -> combScoreMatrix[i] = cmlcrf.predictCombinationScores(classScoreMatrix[i]));
    }

    private void updateAssignmentProbMatrix() {
        IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .forEach(i -> combProbMatrix[i] = cmlcrf.predictCombinationProbs(combScoreMatrix[i]));
    }

    /**
     * Fills the parameter-position lookup tables and the combination/label
     * containment matrix. Each unordered label pair (l1,l2) owns 4 consecutive
     * positions encoding the four indicator cases (00, 10, 01, 11).
     */
    private void initCache() {
        parameterToL1 = new int[numWeightsForLabelPairs];
        parameterToL2 = new int[numWeightsForLabelPairs];
        int start = 0;
        for (int l1 = 0; l1 < numClasses; l1++) {
            for (int l2 = l1 + 1; l2 < numClasses; l2++) {
                parameterToL1[start] = l1;
                parameterToL1[start + 1] = l1;
                parameterToL1[start + 2] = l1;
                parameterToL1[start + 3] = l1;
                parameterToL2[start] = l2;
                parameterToL2[start + 1] = l2;
                parameterToL2[start + 2] = l2;
                parameterToL2[start + 3] = l2;
                start += 4;
            }
        }
        parameterToClass = new int[numWeightsForFeatures];
        parameterToFeature = new int[numWeightsForFeatures];
        for (int i = 0; i < numWeightsForFeatures; i++) {
            parameterToClass[i] = cmlcrf.getWeights().getClassIndex(i);
            parameterToFeature[i] = cmlcrf.getWeights().getFeatureIndex(i);
        }
        comContainsLabel = new boolean[numSupport][numClasses];
        for (int num = 0; num < numSupport; num++) {
            for (int l = 0; l < numClasses; l++) {
                if (supportedCombinations.get(num).matchClass(l)) {
                    comContainsLabel[num][l] = true;
                }
            }
        }
    }

    private void mapPairToCombination() {
        IntStream.range(0, numWeightsForLabelPairs).parallel().forEach(this::mapPairToCombination);
    }

    /**
     * For one label-pair parameter position, collects every supported
     * combination whose (l1,l2) membership matches the position's indicator case.
     */
    private void mapPairToCombination(int position) {
        List<Integer> list = labelPairToCombination.get(position);
        int l1 = parameterToL1[position];
        int l2 = parameterToL2[position];
        int featureCase = position % 4;
        for (int c = 0; c < numSupport; c++) {
            switch (featureCase) {
                // both l1, l2 equal 0
                case 0:
                    if (!comContainsLabel[c][l1] && !comContainsLabel[c][l2]) list.add(c);
                    break;
                // l1 = 1; l2 = 0
                case 1:
                    if (comContainsLabel[c][l1] && !comContainsLabel[c][l2]) list.add(c);
                    break;
                // l1 = 0; l2 = 1
                case 2:
                    if (!comContainsLabel[c][l1] && comContainsLabel[c][l2]) list.add(c);
                    break;
                // l1 = 1; l2 = 1
                case 3:
                    if (comContainsLabel[c][l1] && comContainsLabel[c][l2]) list.add(c);
                    break;
                default:
                    throw new RuntimeException("feature case :" + featureCase + " failed.");
            }
        }
    }

    private void mapCombinationToPair() {
        IntStream.range(0, numSupport).forEach(this::mapCombinationToPair);
    }

    /**
     * For one supported combination, collects every label-pair parameter
     * position whose indicator case matches the combination's label membership.
     * Inverse mapping of {@link #mapPairToCombination(int)}.
     */
    private void mapCombinationToPair(int s) {
        for (int position = 0; position < numWeightsForLabelPairs; position++) {
            int l1 = parameterToL1[position];
            int l2 = parameterToL2[position];
            int featureCase = position % 4;
            switch (featureCase) {
                // both l1, l2 equal 0
                case 0:
                    if (!comContainsLabel[s][l1] && !comContainsLabel[s][l2]) combinationToLabelPair.get(s).add(position);
                    break;
                // l1 = 1; l2 = 0
                case 1:
                    if (comContainsLabel[s][l1] && !comContainsLabel[s][l2]) combinationToLabelPair.get(s).add(position);
                    break;
                // l1 = 0; l2 = 1
                case 2:
                    if (!comContainsLabel[s][l1] && comContainsLabel[s][l2]) combinationToLabelPair.get(s).add(position);
                    break;
                // l1 = 1; l2 = 1
                case 3:
                    if (comContainsLabel[s][l1] && comContainsLabel[s][l2]) combinationToLabelPair.get(s).add(position);
                    break;
                default:
                    throw new RuntimeException("feature case :" + featureCase + " failed.");
            }
        }
    }

    /**
     * @return penalized negative log-likelihood at the current weights (cached;
     *         the cache is invalidated whenever the weights change)
     */
    public double getValue() {
        if (isValueCacheValid) {
            return this.value;
        }
        this.value = getValueForAllData() + getPenalty();
        this.isValueCacheValid = true;
        return this.value;
    }

    /**
     * @return ElasticNet penalty: regularization * ((1-l1Ratio)*0.5*||w||_2^2 + l1Ratio*||w||_1)
     */
    private double getPenalty() {
        Vector vector = cmlcrf.getWeights().getAllWeights();
        double norm = (1 - l1Ratio) * 0.5 * Math.pow(vector.norm(2), 2) + l1Ratio * vector.norm(1);
        return norm * regularization;
    }

    /**
     * @return total negative log-likelihood over the data set at the current weights
     */
    private double getValueForAllData() {
        updateClassScoreMatrix();
        updateAssignmentScoreMatrix();
        IntStream intStream;
        if (isParallel) {
            intStream = IntStream.range(0, dataSet.getNumDataPoints()).parallel();
        } else {
            intStream = IntStream.range(0, dataSet.getNumDataPoints());
        }
        return intStream.mapToDouble(this::getValueForOneData).sum();
    }

    // NLL for one data point: logZ(x_n) - score of the true combination
    private double getValueForOneData(int i) {
        double sum = 0.0;
        // sum logZ(x_n)
        sum += MathUtil.logSumExp(combScoreMatrix[i]);
        // score for the true combination
        sum -= combScoreMatrix[i][labelComIndices[i]];
        return sum;
    }

    private void updateEmpiricalCounts() {
        IntStream intStream;
        if (isParallel) {
            intStream = IntStream.range(0, numParameters).parallel();
        } else {
            intStream = IntStream.range(0, numParameters);
        }
        intStream.forEach(this::calEmpiricalCount);
    }

    private void calEmpiricalCount(int parameterIndex) {
        if (parameterIndex < numWeightsForFeatures) {
            this.empiricalCounts.set(parameterIndex, calEmpiricalCountForFeature(parameterIndex));
        } else if (parameterIndex < numWeightsForFeatures + numWeightsForLabelPairs) {
            this.empiricalCounts.set(parameterIndex, calEmpiricalCountForLabelPair(parameterIndex));
        }
    }

    /**
     * Counts, over all data points, how often the indicator case of this
     * label-pair parameter holds in the observed labels.
     */
    private double calEmpiricalCountForLabelPair(int parameterIndex) {
        double empiricalCount = 0.0;
        int start = parameterIndex - numWeightsForFeatures;
        int l1 = parameterToL1[start];
        int l2 = parameterToL2[start];
        int featureCase = start % 4;
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            MultiLabel label = dataSet.getMultiLabels()[i];
            switch (featureCase) {
                // both l1, l2 equal 0
                case 0:
                    if (!label.matchClass(l1) && !label.matchClass(l2)) empiricalCount += 1.0;
                    break;
                // l1 = 1; l2 = 0
                case 1:
                    if (label.matchClass(l1) && !label.matchClass(l2)) empiricalCount += 1.0;
                    break;
                // l1 = 0; l2 = 1
                case 2:
                    if (!label.matchClass(l1) && label.matchClass(l2)) empiricalCount += 1.0;
                    break;
                // l1 = 1; l2 = 1
                case 3:
                    if (label.matchClass(l1) && label.matchClass(l2)) empiricalCount += 1.0;
                    break;
                default:
                    throw new RuntimeException("feature case :" + featureCase + " failed.");
            }
        }
        return empiricalCount;
    }

    /**
     * Sums the feature value (or 1 for the bias, featureIndex == -1) over the
     * data points whose observed labels contain this parameter's class.
     */
    private double calEmpiricalCountForFeature(int parameterIndex) {
        double empiricalCount = 0.0;
        int classIndex = parameterToClass[parameterIndex];
        int featureIndex = parameterToFeature[parameterIndex];
        if (featureIndex == -1) {
            for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
                if (dataSet.getMultiLabels()[i].matchClass(classIndex)) {
                    empiricalCount += 1;
                }
            }
        } else {
            Vector column = dataSet.getColumn(featureIndex);
            MultiLabel[] multiLabels = dataSet.getMultiLabels();
            for (Vector.Element element : column.nonZeroes()) {
                int dataIndex = element.index();
                double featureValue = element.get();
                if (multiLabels[dataIndex].matchClass(classIndex)) {
                    empiricalCount += featureValue;
                }
            }
        }
        return empiricalCount;
    }

    private void updateClassProbMatrix() {
        IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .forEach(i -> classProbMatrix[i] = cmlcrf.calClassProbs(combProbMatrix[i]));
    }

    private void updatePredictedCounts() {
        IntStream.range(0, numWeightsForFeatures).parallel()
                .forEach(i -> predictedCounts.set(i, calPredictedFeatureCounts(i)));
        IntStream.range(numWeightsForFeatures, numParameters).parallel()
                .forEach(i -> predictedCounts.set(i, calPredictedLabelPairCounts(i)));
    }

    private double calPredictedLabelPairCounts(int parameterIndex) {
        double count = 0.0;
        int pos = parameterIndex - numWeightsForFeatures;
        for (int matched : labelPairToCombination.get(pos)) {
            count += combProbSums[matched];
        }
        return count;
    }

    private double calPredictedFeatureCounts(int parameterIndex) {
        double count = 0.0;
        int classIndex = parameterToClass[parameterIndex];
        int featureIndex = parameterToFeature[parameterIndex];
        if (featureIndex == -1) {
            for (int i = 0; i < numData; i++) {
                count += this.classProbMatrix[i][classIndex];
            }
        } else {
            Vector featureColumn = dataSet.getColumn(featureIndex);
            for (Vector.Element element : featureColumn.nonZeroes()) {
                int dataPointIndex = element.index();
                double featureValue = element.get();
                count += this.classProbMatrix[dataPointIndex][classIndex] * featureValue;
            }
        }
        return count;
    }

    private void updateCombProbSums(int combinationIndex) {
        double sum = 0;
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            sum += combProbMatrix[i][combinationIndex];
        }
        combProbSums[combinationIndex] = sum;
    }

    private void updateCombProbSums() {
        IntStream.range(0, numSupport).parallel()
                .forEach(this::updateCombProbSums);
    }

    /**
     * A backtracking line search for sufficient decrease with an
     * ElasticNet-penalized model.
     * Reference: An improved glmnet for l1-regularized logistic regression.
     *
     * <p>BUG FIX: the original never invalidated the objective cache after
     * {@code setWeightVector}, so {@code getValue()} inside the loop kept
     * returning the objective at the starting point; the acceptance test then
     * never depended on the trial point and could loop forever. The cache is
     * now invalidated before every re-evaluation, and the step length is
     * guarded against underflow.
     *
     * @param searchDirection direction to search along (accumulated subproblem solutions)
     * @param gradient        gradient of the smooth part at the starting point
     */
    private void lineSearch(Vector searchDirection, Vector gradient) {
        double initialStepLength = 1;
        double shrinkage = 0.5;
        double c = 1e-4;
        double minStepLength = 1e-10; // give up backtracking below this step size
        double stepLength = initialStepLength;
        Vector start = cmlcrf.getWeights().getAllWeights();
        double penalty = getPenalty();
        double value = getValue();
        double product = gradient.dot(searchDirection);
        while (true) {
            Vector step = searchDirection.times(stepLength);
            Vector target = start.plus(step);
            cmlcrf.getWeights().setWeightVector(target);
            // weights changed: force getValue() to recompute at the trial point
            isValueCacheValid = false;
            double targetPenalty = getPenalty();
            double targetValue = getValue();
            // sufficient-decrease condition with the penalty difference folded in
            if (targetValue <= value + c * stepLength * (product + targetPenalty - penalty)) {
                break;
            }
            if (stepLength <= minStepLength) {
                // no acceptable step found; keep the last (tiny) trial point
                break;
            }
            stepLength *= shrinkage;
        }
    }
}