Java Code Examples for edu.stanford.nlp.stats.Counter#getCount()

The following examples show how to use edu.stanford.nlp.stats.Counter#getCount(). You can vote up the examples you like or vote down the ones you don't, and follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: AdaGradFOBOSUpdater.java    From phrasal with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Performs one per-feature update of {@code weights} for every feature in
 * {@code gradient}: a gradient step with a per-feature rate, followed by a
 * shrinkage of the magnitude by {@code currentRate * lambda}.
 *
 * @param weights  weight vector, modified in place; a feature whose new value
 *                 is exactly zero is removed instead of stored
 * @param gradient gradient of the current step; only its keys are visited
 * @param timeStep not read by this method
 */
public void updateL1(Counter<String> weights,
     Counter<String> gradient, int timeStep) {
  for (String feature : gradient.keySet()) {
    final double grad = gradient.getCount(feature);
    // The value returned by incrementCount is used as the accumulated
    // squared gradient for this feature.
    final double accumulated = sumGradSquare.incrementCount(feature, grad * grad);
    final double stepSize = rate / (Math.sqrt(accumulated) + eps);

    // Plain gradient step before shrinkage: w - stepSize * g
    final double candidate = weights.getCount(feature) - (stepSize * grad);
    final double shrunk = Math.signum(candidate)
        * pospart( Math.abs(candidate) - stepSize*this.lambda );

    if (shrunk != 0.0) {
      weights.setCount(feature, shrunk);
    } else {
      // Keep the weight vector sparse: drop features that shrink to zero.
      weights.remove(feature);
    }
  }
}
 
Example 2
Source File: MetricUtils.java    From phrasal with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Computes the "informativeness" of every n-gram, as used by the NIST metric.
 * For an n-gram w_1:n the value is -log2(count(w_1:n) / count(w_1:n-1)); for
 * a unigram the denominator is {@code totWords}.
 *
 * @param ngramCounts
 *          n-gram counts according to the references
 * @param totWords
 *          total number of words, used as the denominator for unigrams
 * @return a new counter mapping each key of {@code ngramCounts} to its
 *         informativeness
 */
static public <TK> Counter<Sequence<TK>> getNGramInfo(
    Counter<Sequence<TK>> ngramCounts, int totWords) {
  Counter<Sequence<TK>> info = new ClassicCounter<Sequence<TK>>();

  for (Sequence<TK> ngram : ngramCounts.keySet()) {
    final int length = ngram.size();
    // Longer n-grams are normalized by the count of their (n-1)-gram prefix.
    final double context = (length > 1)
        ? ngramCounts.getCount(ngram.subsequence(0, length - 1))
        : totWords;
    final double ratio = ngramCounts.getCount(ngram) / context;
    info.setCount(ngram, -Math.log(ratio) / LOG2);
  }
  return info;
}
 
Example 3
Source File: MetricUtils.java    From phrasal with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Computes, for every n-gram, the maximum (optionally weighted) count
 * observed in any single sequence.
 *
 * @param sequences - The list of sequences.
 * @param seqWeights - Per-sequence multipliers applied to the counts, or
 *          {@code null} to weight every sequence by 1.0.
 * @param maxOrder - The n-gram order.
 * @return a new counter (default return value 0.0) holding the per-n-gram maxima
 * @throws RuntimeException if {@code seqWeights} is non-null and its length
 *           differs from the number of sequences
 */
static public <TK> Counter<Sequence<TK>> getMaxNGramCounts(
    List<Sequence<TK>> sequences, double[] seqWeights, int maxOrder) {
  Counter<Sequence<TK>> best = new ClassicCounter<Sequence<TK>>();
  best.setDefaultReturnValue(0.0);
  if(seqWeights != null && seqWeights.length != sequences.size()) {
    throw new RuntimeException("Improper weight vector for sequences.");
  }

  int index = 0;
  for (Sequence<TK> sequence : sequences) {
    // The weight is constant for the whole sequence.
    final double weight = (seqWeights == null) ? 1.0 : seqWeights[index];
    final Counter<Sequence<TK>> local = getNGramCounts(sequence, maxOrder);
    for (Sequence<TK> ngram : local.keySet()) {
      final double weighted = weight * local.getCount(ngram);
      best.setCount(ngram, Math.max(weighted, best.getCount(ngram)));
    }
    index++;
  }
  return best;
}
 
Example 4
Source File: NISTMetric.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Sums, per n-gram length, the clipped counts weighted by the stored n-gram
 * information value.
 *
 * @param clippedCounts clipped n-gram counts; only strictly positive entries contribute
 * @return an array of size {@code order}; index {@code len - 1} holds the
 *         total for n-grams of length {@code len}
 */
private double[] localMatchCounts(Counter<Sequence<TK>> clippedCounts) {
  final double[] totals = new double[order];
  for (Sequence<TK> ngram : clippedCounts.keySet()) {
    final double clipped = clippedCounts.getCount(ngram);
    if (!(clipped > 0)) {
      continue;
    }
    if (!ngramInfo.containsKey(ngram)) {
      // No information value available: report and skip the n-gram.
      System.err.println("Missing key for " + ngram.toString());
      continue;
    }
    totals[ngram.size() - 1] += clipped * ngramInfo.getCount(ngram);
  }
  return totals;
}
 
Example 5
Source File: SentencelevelBLEUVariance.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Sums the strictly positive clipped counts per n-gram length.
 *
 * @param clippedCounts clipped n-gram counts
 * @param order         size of the returned array
 * @return an array where index {@code len - 1} holds the total count of
 *         n-grams of length {@code len}
 */
private static <TK> double[] localMatchCounts(Counter<Sequence<TK>> clippedCounts, int order) {
  final double[] totals = new double[order];
  for (Sequence<TK> ngram : clippedCounts.keySet()) {
    final double clipped = clippedCounts.getCount(ngram);
    if (!(clipped > 0.0)) {
      continue;
    }
    totals[ngram.size() - 1] += clipped;
  }
  return totals;
}
 
Example 6
Source File: ComputeBitextIDF.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Reads whitespace-tokenized lines from standard input, counts for each
 * token the number of lines it occurs in, and prints one
 * {@code token<TAB>log(numLines / lineCount)} row per token to standard
 * output, followed by a row for {@code UNK_TOKEN}.
 *
 * @param args must be empty; otherwise a usage message is printed and the
 *             program exits with status -1
 */
public static void main(String[] args) {
  if (args.length > 0) {
    System.err.printf("Usage: java %s < files > idf-file%n", ComputeBitextIDF.class.getName());
    System.exit(-1);
  }

  Counter<String> docFrequency = new ClassicCounter<String>(1000000);
  LineNumberReader input = new LineNumberReader(new InputStreamReader(System.in));
  double numLines = 0.0;
  try {
    String line = input.readLine();
    while (line != null) {
      String[] tokens = line.trim().split("\\s+");
      // Each token is counted at most once per line.
      Set<String> seenOnLine = new HashSet<String>(tokens.length);
      for (String token : tokens) {
        if (seenOnLine.add(token)) {
          docFrequency.incrementCount(token);
        }
      }
      line = input.readLine();
    }
    numLines = input.getLineNumber();
    input.close();
  } catch (IOException e) {
    e.printStackTrace();
  }

  // Output the idfs
  System.err.printf("Bitext contains %d sentences and %d word types%n", (int) numLines, docFrequency.keySet().size());
  for (String wordType : docFrequency.keySet()) {
    final double linesWithWord = docFrequency.getCount(wordType);
    System.out.printf("%s\t%f%n", wordType, Math.log(numLines / linesWithWord));
  }
  // The unknown-word row uses a count of 1.0.
  System.out.printf("%s\t%f%n", UNK_TOKEN, Math.log(numLines / 1.0));
}
 
Example 7
Source File: PerceptronOptimizer.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Repeatedly line-searches from the current weights along the direction
 * (target features - current one-best features) until the squared change in
 * the weights falls below {@code MERT.NO_PROGRESS_SSD}.
 *
 * @param initialWts starting weights
 * @return the weights produced by the last line search
 */
@Override
public Counter<String> optimize(Counter<String> initialWts) {

  final List<ScoredFeaturizedTranslation<IString, String>> target =
      new HillClimbingMultiTranslationMetricMax<IString, String>(emetric)
          .maximize(nbest);
  final Counter<String> targetFeatures = MERT.summarizedAllFeaturesVector(target);

  Counter<String> current = initialWts;
  double ssd;
  do {
    final Scorer<String> scorer = new DenseScorer(current, MERT.featureIndex);
    final MultiTranslationMetricMax<IString, String> oneBestSearch =
        new HillClimbingMultiTranslationMetricMax<IString, String>(
            new ScorerWrapperEvaluationMetric<IString, String>(scorer));
    final List<ScoredFeaturizedTranslation<IString, String>> oneBest =
        oneBestSearch.maximize(nbest);

    // direction = targetFeatures - oneBestFeatures
    final Counter<String> direction = MERT.summarizedAllFeaturesVector(oneBest);
    Counters.multiplyInPlace(direction, -1.0);
    direction.addAll(targetFeatures);

    final Counter<String> updated = mert.lineSearch(nbest, current, direction, emetric);

    // Sum of squared differences over the keys of the new weights.
    ssd = 0;
    for (String key : updated.keySet()) {
      final double delta = current.getCount(key) - updated.getCount(key);
      ssd += delta * delta;
    }
    current = updated;
  } while (!(ssd < MERT.NO_PROGRESS_SSD));
  return current;
}
 
Example 8
Source File: NISTMetric.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * For every data point, stores the length of each of its references in
 * {@code refLengths} and appends a counter of maximum reference n-gram
 * counts to {@code maxReferenceCounts}.
 *
 * @param referencesList one list of reference sequences per data point
 * @throws RuntimeException if a data point has no references
 */
private void initReferences(List<List<Sequence<TK>>> referencesList) {
  final int numPoints = referencesList.size();
  for (int point = 0; point < numPoints; point++) {
    final List<Sequence<TK>> refs = referencesList.get(point);
    final int numRefs = refs.size();
    if (numRefs == 0) {
      throw new RuntimeException(String.format(
          "No references found for data point: %d\n", point));
    }

    final int[] lengths = new int[numRefs];
    refLengths[point] = lengths;
    final Counter<Sequence<TK>> maxCounts =
        MetricUtils.getMaxNGramCounts(refs, order);
    maxReferenceCounts.add(maxCounts);
    lengths[0] = refs.get(0).size();

    for (int r = 1; r < numRefs; r++) {
      final Sequence<TK> ref = refs.get(r);
      lengths[r] = ref.size();
      final Counter<Sequence<TK>> refCounts = MetricUtils.getNGramCounts(ref, order);
      // Iterate over a snapshot of the keys rather than the live key set.
      for (Sequence<TK> ngram : new HashSet<Sequence<TK>>(refCounts.keySet())) {
        final double candidate = refCounts.getCount(ngram);
        // Raise the stored maximum only when this reference exceeds it.
        if (maxCounts.getCount(ngram) < candidate) {
          maxCounts.setCount(ngram, candidate);
        }
      }
    }
  }
}
 
Example 9
Source File: BLEUMetric.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Accumulates the strictly positive clipped counts by n-gram length.
 *
 * @param clippedCounts clipped n-gram counts
 * @param order         size of the returned array
 * @return an array where index {@code len - 1} holds the total count of
 *         n-grams of length {@code len}
 */
private static <TK> double[] localMatchCounts(Counter<Sequence<TK>> clippedCounts, int order) {
  final double[] perLength = new double[order];
  for (Sequence<TK> ngram : clippedCounts.keySet()) {
    final double matched = clippedCounts.getCount(ngram);
    if (matched > 0.0) {
      perLength[ngram.size() - 1] += matched;
    }
  }
  return perLength;
}
 
Example 10
Source File: MERT.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Computes the sum of squared differences between two weight vectors,
 * taken over the keys of {@code newWts} only.
 *
 * @param oldWts previous weights
 * @param newWts new weights; its key set determines which entries are compared
 * @return the sum over keys k of (oldWts[k] - newWts[k])^2
 */
static public double wtSsd(Counter<String> oldWts, Counter<String> newWts) {
  double total = 0;
  for (String name : newWts.keySet()) {
    final double before = oldWts.getCount(name);
    final double after = newWts.getCount(name);
    final double delta = before - after;
    total += delta * delta;
  }
  return total;
}
 
Example 11
Source File: OptimizerUtils.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Adds an independent random offset, drawn as {@code Math.random() * scale},
 * to every weight already present in the vector.
 *
 * @param wts   weight vector, modified in place
 * @param scale multiplier applied to each random draw
 */
public static void randomizeWeightsInPlace(Counter<String> wts, double scale) {
  for (String name : wts.keySet()) {
    final double perturbed = wts.getCount(name) + Math.random() * scale;
    wts.setCount(name, perturbed);
  }
}
 
Example 12
Source File: OptimizerUtils.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Scores a translation as the dot product of its feature values with the
 * given weights.
 *
 * @param wts   feature weights, looked up by feature name
 * @param trans translation whose {@code features} are scored
 * @return the sum of {@code value * weight} over the translation's features
 */
public static double scoreTranslation(Counter<String> wts,
    ScoredFeaturizedTranslation<IString, String> trans) {
  double total = 0;
  for (FeatureValue<String> feature : trans.features) {
    final double weight = wts.getCount(feature.name);
    total += feature.value * weight;
  }
  return total;
}
 
Example 13
Source File: OptimizerUtils.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Copies weights out of a counter into an array ordered like
 * {@code weightNames}.
 *
 * @param weightNames feature names, one per output position
 * @param wts         counter the weights are read from
 * @return a new array where element {@code i} is the count of
 *         {@code weightNames[i]}
 */
public static double[] getWeightArrayFromCounter(String[] weightNames,
    Counter<String> wts) {
  final double[] values = new double[weightNames.length];
  int position = 0;
  for (String name : weightNames) {
    values[position++] = wts.getCount(name);
  }
  return values;
}
 
Example 14
Source File: AbstractOnlineOptimizer.java    From phrasal with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Sums the unregularized gradients of all examples in the batch that have at
 * least one translation and, when L2 regularization is enabled, adds the
 * regularization term for every feature in {@code weights}.
 *
 * @param weights          current weights
 * @param sources          source sentences, parallel to {@code sourceIds}
 * @param sourceIds        ids of the batch examples
 * @param translations     translation lists, parallel to {@code sourceIds}
 * @param references       reference lists, parallel to {@code sourceIds}
 * @param referenceWeights passed through to the per-example gradient
 * @param scoreMetric      passed through to the per-example gradient
 * @return a new counter holding the batch gradient
 */
@Override
public Counter<String> getBatchGradient(Counter<String> weights,
    List<Sequence<IString>> sources, int[] sourceIds,
    List<List<RichTranslation<IString, String>>> translations,
    List<List<Sequence<IString>>> references,
    double[] referenceWeights,
    SentenceLevelMetric<IString, String> scoreMetric) {
  Counter<String> batchGradient = new ClassicCounter<String>();

  for (int i = 0; i < sourceIds.length; i++) {
    final List<RichTranslation<IString, String>> hypotheses = translations.get(i);
    if (hypotheses.size() > 0) {
      // Skip decoder failures.
      Counter<String> unregularizedGradient = getUnregularizedGradient(weights, sources.get(i), sourceIds[i], hypotheses, references.get(i), referenceWeights, scoreMetric);
      batchGradient.addAll(unregularizedGradient);
    }
  }

  // Add L2 regularization directly into the derivative
  if (this.l2Regularization) {
    // Snapshot of the weight keys. The original re-added the same key set
    // to this snapshot a second time, which was a no-op and has been removed.
    final Set<String> features = new HashSet<String>(weights.keySet());
    // Scale the penalty by the fraction of the tuning set in this batch.
    final double dataFraction = sourceIds.length /(double) tuneSetSize;
    final double scaledInvSigmaSquared = dataFraction/(2*sigmaSq);
    for (String key : features) {
      double x = weights.getCount(key);
      batchGradient.incrementCount(key, x * scaledInvSigmaSquared);
    }
  }

  return batchGradient;
}
 
Example 15
Source File: Summarizer.java    From wiseowl with MIT License 4 votes vote down vote up
public Summarizer(Counter<String> dfCounter) {
  this.dfCounter = dfCounter;
  this.numDocuments = (int) dfCounter.getCount("__all__");
}
 
Example 16
Source File: CRFPostprocessor.java    From phrasal with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Evaluate the postprocessor given an input file specified in the flags.
 * Every classified line is also written to the file {@code apply.out}, and
 * overall and per-label accuracies are printed to {@code pwOut}.
 *
 * @param preProcessor preprocessor handed to the document reader
 * @param pwOut        destination for the evaluation report
 */
protected void evaluate(Preprocessor preProcessor, PrintWriter pwOut) {
  System.err.println("Starting evaluation...");
  DocumentReaderAndWriter<CoreLabel> docReader = new ProcessorTools.PostprocessorDocumentReaderAndWriter(preProcessor);
  ObjectBank<List<CoreLabel>> lines =
    classifier.makeObjectBankFromFile(flags.testFile, docReader);

  Counter<String> labelTotal = new ClassicCounter<String>();
  Counter<String> labelCorrect = new ClassicCounter<String>();
  int total = 0;
  int correct = 0;
  PrintWriter pw = new PrintWriter(IOTools.getWriterFromFile("apply.out"));
  try {
    for (List<CoreLabel> line : lines) {
      line = classifier.classify(line);
      pw.println(Sentence.listToString(ProcessorTools.toPostProcessedSequence(line)));
      total += line.size();
      for (CoreLabel label : line) {
        String hypothesis = label.get(CoreAnnotations.AnswerAnnotation.class);
        String reference = label.get(CoreAnnotations.GoldAnswerAnnotation.class);
        labelTotal.incrementCount(reference);
        if (hypothesis.equals(reference)) {
          correct++;
          labelCorrect.incrementCount(reference);
        }
      }
    }
  } finally {
    // Close the output file even if classification fails part-way through.
    pw.close();
  }

  double accuracy = ((double) correct) / ((double) total);
  accuracy *= 100.0;

  pwOut.println("EVALUATION RESULTS");
  pwOut.printf("#datums:\t%d%n", total);
  pwOut.printf("#correct:\t%d%n", correct);
  pwOut.printf("accuracy:\t%.2f%n", accuracy);
  pwOut.println("==================");

  // Output the per label accuracies
  pwOut.println("PER LABEL ACCURACIES");
  for (String refLabel : labelTotal.keySet()) {
    double nTotal = labelTotal.getCount(refLabel);
    double nCorrect = labelCorrect.getCount(refLabel);
    double acc = (nCorrect / nTotal) * 100.0;
    pwOut.printf(" %s\t%.2f%n", refLabel, acc);
  }
}
 
Example 17
Source File: SoftmaxMaxMarginSlackRescaling.java    From phrasal with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Minimizes a {@code SoftMaxMarginSlackRescaling} objective with a
 * {@code QNMinimizer}, starting from the given weights, reports convergence
 * diagnostics on standard error and passes the result to
 * {@code MERT.updateBest}.
 *
 * @param initialWts starting weights; copied, not modified here
 * @return the weights found by the minimizer
 */
@SuppressWarnings("unchecked")
public Counter<String> optimize(Counter<String> initialWts) {
  final Counter<String> wts = new ClassicCounter<String>(initialWts);

  final EvaluationMetric<IString, String> modelMetric =
      new LinearCombinationMetric<IString, String>(
          new double[] { 1.0 },
          new ScorerWrapperEvaluationMetric<IString, String>(
              new DenseScorer(initialWts)));

  final List<ScoredFeaturizedTranslation<IString, String>> current =
      new HillClimbingMultiTranslationMetricMax<IString, String>(modelMetric)
          .maximize(nbest);

  final List<ScoredFeaturizedTranslation<IString, String>> target =
      new HillClimbingMultiTranslationMetricMax<IString, String>(emetric)
          .maximize(nbest);

  System.err.println("Target model: " + modelMetric.score(target)
      + " metric: " + emetric.score(target));
  System.err.println("Current model: " + modelMetric.score(current)
      + " metric: " + emetric.score(current));

  // Fix an ordering of the features so weights can be handled as an array.
  final int dimension = wts.size();
  final String[] weightNames = new String[dimension];
  final double[] startPoint = new double[dimension];
  int position = 0;
  for (String feature : wts.keySet()) {
    weightNames[position] = feature;
    startPoint[position] = wts.getCount(feature);
    position++;
  }

  final double[][] lossMatrix = OptimizerUtils.calcDeltaMetric(nbest, target,
      emetric);

  // scale local relative loss by dataset size x 100
  // loss is then on a per sentence
  // BLEU 0-100 scale rather than BLEU 0-1.0
  final int lossScale = lossMatrix.length * 100;
  for (double[] row : lossMatrix) {
    for (int j = 0; j < row.length; j++) {
      row[j] *= lossScale;
    }
  }

  // Summary statistics of the scaled losses, for logging only.
  double lossSum = 0;
  double lossMax = Double.NEGATIVE_INFINITY;
  int lossCnt = 0;
  for (double[] row : lossMatrix) {
    for (double loss : row) {
      lossCnt++;
      lossSum += loss;
      if (loss > lossMax) {
        lossMax = loss;
      }
    }
  }
  System.err.printf("Loss Avg: %e\n", lossSum / lossCnt);
  System.err.printf("Loss Max: %e\n", lossMax);

  final Minimizer<DiffFunction> minimizer = new QNMinimizer(15, true);
  final SoftMaxMarginSlackRescaling objective = new SoftMaxMarginSlackRescaling(
      weightNames, target, lossMatrix);

  final double firstValue = objective.valueAt(startPoint);
  if (firstValue == Double.POSITIVE_INFINITY || Double.isNaN(firstValue)) {
    System.err
        .printf("Initial Objective is infinite/NaN - normalizing weight vector");
    // Rescale the starting point by the L2 norm of the weights.
    final double normTerm = Counters.L2Norm(wts);
    for (int i = 0; i < startPoint.length; i++) {
      startPoint[i] /= normTerm;
    }
  }
  final double initialObjValue = objective.valueAt(startPoint);
  final double initialDNorm = OptimizerUtils.norm2DoubleArray(
      objective.derivativeAt(startPoint));
  final double initialXNorm = OptimizerUtils.norm2DoubleArray(startPoint);

  System.err.println("Initial Objective value: " + initialObjValue);
  final double[] solution = minimizer.minimize(objective, 1e-4, startPoint);
  final Counter<String> newWts = OptimizerUtils.getWeightCounterFromArray(
      weightNames, solution);
  final double finalObjValue = objective.valueAt(solution);

  final double objDiff = initialObjValue - finalObjValue;
  final double finalDNorm = OptimizerUtils.norm2DoubleArray(
      objective.derivativeAt(solution));
  final double finalXNorm = OptimizerUtils.norm2DoubleArray(solution);
  final double metricEval = MERT.evalAtPoint(nbest, newWts, emetric);
  System.err.println(">>>[Converge Info] ObjInit(" + initialObjValue
      + ") - ObjFinal(" + finalObjValue + ") = ObjDiff(" + objDiff
      + ") L2DInit(" + initialDNorm + ") L2DFinal(" + finalDNorm
      + ") L2XInit(" + initialXNorm + ") L2XFinal(" + finalXNorm + ")");

  MERT.updateBest(newWts, metricEval, true);

  return newWts;
}
 
Example 18
Source File: SoftmaxMaxMarginMarkovNetwork.java    From phrasal with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Minimizes a {@code SoftMaxMarginMarkovNetwork} objective with a
 * {@code QNMinimizer}, starting from the given weights, reports convergence
 * diagnostics on standard error and passes the result to
 * {@code MERT.updateBest}.
 *
 * @param initialWts starting weights; copied, not modified here
 * @return the weights found by the minimizer
 */
@SuppressWarnings("unchecked")
public Counter<String> optimize(Counter<String> initialWts) {
  final Counter<String> wts = new ClassicCounter<String>(initialWts);

  final EvaluationMetric<IString, String> modelMetric =
      new LinearCombinationMetric<IString, String>(
          new double[] { 1.0 },
          new ScorerWrapperEvaluationMetric<IString, String>(
              new DenseScorer(initialWts)));

  final List<ScoredFeaturizedTranslation<IString, String>> current =
      new HillClimbingMultiTranslationMetricMax<IString, String>(modelMetric)
          .maximize(nbest);

  final List<ScoredFeaturizedTranslation<IString, String>> target =
      new HillClimbingMultiTranslationMetricMax<IString, String>(emetric)
          .maximize(nbest);

  System.err.println("Target model: " + modelMetric.score(target)
      + " metric: " + emetric.score(target));
  System.err.println("Current model: " + modelMetric.score(current)
      + " metric: " + emetric.score(current));

  // Fix an ordering of the features so weights can be handled as an array.
  final int dimension = wts.size();
  final String[] weightNames = new String[dimension];
  final double[] startPoint = new double[dimension];
  int position = 0;
  for (String feature : wts.keySet()) {
    weightNames[position] = feature;
    startPoint[position] = wts.getCount(feature);
    position++;
  }

  final double[][] lossMatrix = OptimizerUtils.calcDeltaMetric(nbest, target,
      emetric);

  final Minimizer<DiffFunction> minimizer = new QNMinimizer(15, true);
  final SoftMaxMarginMarkovNetwork objective = new SoftMaxMarginMarkovNetwork(
      weightNames, target, lossMatrix);

  final double firstValue = objective.valueAt(startPoint);
  if (firstValue == Double.POSITIVE_INFINITY || Double.isNaN(firstValue)) {
    System.err
        .printf("Initial Objective is infinite/NaN - normalizing weight vector");
    // Rescale the starting point by the L2 norm of the weights.
    final double normTerm = Counters.L2Norm(wts);
    for (int i = 0; i < startPoint.length; i++) {
      startPoint[i] /= normTerm;
    }
  }
  final double initialObjValue = objective.valueAt(startPoint);
  final double initialDNorm = OptimizerUtils.norm2DoubleArray(
      objective.derivativeAt(startPoint));
  final double initialXNorm = OptimizerUtils.norm2DoubleArray(startPoint);

  System.err.println("Initial Objective value: " + initialObjValue);
  final double[] solution = minimizer.minimize(objective, 1e-4, startPoint);
  final Counter<String> newWts = OptimizerUtils.getWeightCounterFromArray(
      weightNames, solution);
  final double finalObjValue = objective.valueAt(solution);

  final double objDiff = initialObjValue - finalObjValue;
  final double finalDNorm = OptimizerUtils.norm2DoubleArray(
      objective.derivativeAt(solution));
  final double finalXNorm = OptimizerUtils.norm2DoubleArray(solution);
  final double metricEval = MERT.evalAtPoint(nbest, newWts, emetric);
  System.err.println(">>>[Converge Info] ObjInit(" + initialObjValue
      + ") - ObjFinal(" + finalObjValue + ") = ObjDiff(" + objDiff
      + ") L2DInit(" + initialDNorm + ") L2DFinal(" + finalDNorm
      + ") L2XInit(" + initialXNorm + ") L2XFinal(" + finalXNorm + ")");

  MERT.updateBest(newWts, metricEval, true);

  return newWts;
}
 
Example 19
Source File: RandomAltPairs.java    From phrasal with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Repeatedly line-searches toward the feature vector of randomly chosen
 * translations, then along that vector minus the current weights, until the
 * weights have changed by less than {@code MERT.NO_PROGRESS_SSD} for
 * {@code MERT.NO_PROGRESS_LIMIT} consecutive iterations.
 *
 * @param initialWts starting weights
 * @return the weights produced by the last iteration
 */
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  System.err.printf("RandomAltPairs forceBetter = %b\n", forceBetter);
  Counter<String> wts = initialWts;

  int noProgress = 0;
  while (noProgress < MERT.NO_PROGRESS_LIMIT) {
    final Scorer<String> scorer = new DenseScorer(wts, MERT.featureIndex);

    final List<ScoredFeaturizedTranslation<IString, String>> randomPick;
    if (forceBetter) {
      randomPick = mert.randomBetterTranslations(nbest, wts, emetric);
    } else {
      randomPick = mert.randomTranslations(nbest);
    }
    final Counter<String> direction = MERT.summarizedAllFeaturesVector(randomPick);
    // search toward random better translation
    final Counter<String> firstPass = mert.lineSearch(nbest, wts, direction, emetric);

    final MultiTranslationMetricMax<IString, String> oneBestSearch =
        new HillClimbingMultiTranslationMetricMax<IString, String>(
            new ScorerWrapperEvaluationMetric<IString, String>(scorer));
    // The one-best search is still run, although its result is not used below.
    oneBestSearch.maximize(nbest);

    Counters.subtractInPlace(direction, wts);

    System.err.printf("Random alternate score: %.5f \n",
        emetric.score(randomPick));

    final Counter<String> newWts = mert.lineSearch(nbest, firstPass, direction, emetric);
    final double eval = MERT.evalAtPoint(nbest, newWts, emetric);

    // Sum of squared differences over the keys of the new weights.
    double ssd = 0;
    for (String key : newWts.keySet()) {
      final double delta = wts.getCount(key) - newWts.getCount(key);
      ssd += delta * delta;
    }
    System.err.printf("Eval: %.5f SSD: %e (no progress: %d)\n", eval, ssd,
        noProgress);
    wts = newWts;
    noProgress = (ssd < MERT.NO_PROGRESS_SSD) ? noProgress + 1 : 0;
  }
  return wts;
}
 
Example 20
Source File: DownhillSimplexOptimizer.java    From phrasal with GNU General Public License v3.0 4 votes vote down vote up
/**
 * Minimizes a {@code MERTObjective} with a {@code DownhillSimplexMinimizer},
 * starting from the given weights, and passes the result to
 * {@code MERT.updateBest}.
 *
 * @param initialWts starting weights; copied, not modified here
 * @return a new counter holding the weights found by the minimizer
 */
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  final Counter<String> wts = new ClassicCounter<String>(initialWts);

  // Fix an ordering of the features so weights can be handled as an array.
  final String[] weightNames = new String[initialWts.size()];
  final double[] startPoint = new double[initialWts.size()];
  int position = 0;
  for (String feature : wts.keySet()) {
    weightNames[position] = feature;
    startPoint[position] = wts.getCount(feature);
    position++;
  }

  final Minimizer<Function> simplex = new DownhillSimplexMinimizer();
  final MERTObjective objective = new MERTObjective(weightNames);

  final double firstValue = objective.valueAt(startPoint);
  if (firstValue == Double.POSITIVE_INFINITY || Double.isNaN(firstValue)) {
    System.err
        .printf("Initial Objective is infinite/NaN - normalizing weight vector");
    // Rescale the starting point by the L2 norm of the weights.
    final double normTerm = Counters.L2Norm(wts);
    for (int i = 0; i < startPoint.length; i++) {
      startPoint[i] /= normTerm;
    }
  }

  final double initialObjValue = objective.valueAt(startPoint);
  System.err.println("Initial Objective value: " + initialObjValue);

  final double[] solution = simplex.minimize(objective, 1e-6, startPoint);

  // Map the solution array back to named weights.
  final Counter<String> newWts = new ClassicCounter<String>();
  int index = 0;
  for (String name : weightNames) {
    newWts.setCount(name, solution[index]);
    index++;
  }

  final double finalObjValue = objective.valueAt(solution);
  System.err.println("Final Objective value: " + finalObjValue);

  final double metricEval = MERT.evalAtPoint(nbest, newWts, emetric);
  MERT.updateBest(newWts, metricEval);
  return newWts;
}