Java Code Examples for edu.stanford.nlp.stats.Counter#keySet()

The following examples show how to use edu.stanford.nlp.stats.Counter#keySet(). Each example is taken from the source file listed above it (all from the phrasal project).
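
Before the project examples, here is a minimal, self-contained sketch of the pattern they share: iterate over Counter#keySet() and read or update the count stored for each key. This snippet is not taken from phrasal; the class name KeySetDemo and the sample tokens are invented for illustration, and it assumes Stanford CoreNLP is on the classpath.

import java.util.HashSet;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class KeySetDemo {
  public static void main(String[] args) {
    // Count token occurrences.
    Counter<String> counts = new ClassicCounter<String>();
    for (String token : new String[] { "a", "b", "a", "c", "a" }) {
      counts.incrementCount(token);
    }

    // Read-only pass over the keys, as in most of the examples below.
    for (String key : counts.keySet()) {
      System.out.printf("%s\t%.1f%n", key, counts.getCount(key));
    }

    // When removing entries from the counter being iterated, copy the
    // key set first so the live key view is not modified mid-iteration.
    for (String key : new HashSet<String>(counts.keySet())) {
      if (counts.getCount(key) < 2.0) {
        counts.remove(key);
      }
    }
  }
}
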
Example 1
Source File: MetricUtils.java    From phrasal with GNU General Public License v3.0
/**
 * Calculates the "informativeness" of each ngram, which is used by the NIST
 * metric. In Matlab notation, the informativeness of the ngram w_1:n is
 * defined as -log2(count(w_1:n)/count(w_1:n-1)).
 * 
 * @param ngramCounts
 *          ngram counts according to references
 * @param totWords
 *          total number of words, which is used to compute the
 *          informativeness of unigrams.
 * @return a counter mapping each ngram to its informativeness
 */
static public <TK> Counter<Sequence<TK>> getNGramInfo(
    Counter<Sequence<TK>> ngramCounts, int totWords) {
  Counter<Sequence<TK>> ngramInfo = new ClassicCounter<Sequence<TK>>();

  for (Sequence<TK> ngram : ngramCounts.keySet()) {
    double num = ngramCounts.getCount(ngram);
    double denom = totWords;
    if (ngram.size() > 1) {
      Sequence<TK> ngramPrefix = ngram.subsequence(0,
          ngram.size() - 1);
      denom = ngramCounts.getCount(ngramPrefix);
    }
    double inf = -Math.log(num / denom) / LOG2;
    ngramInfo.setCount(ngram, inf);
    // System.err.printf("ngram info: %s %.3f\n", ngram.toString(), inf);
  }
  return ngramInfo;
}
 
Example 2
Source File: MetricUtils.java    From phrasal with GNU General Public License v3.0
/**
 * Compute maximum n-gram counts from one or more sequences.
 * 
 * @param sequences - The list of sequences.
 * @param seqWeights - Optional per-sequence weights; if null, each sequence is weighted 1.0.
 * @param maxOrder - The maximum n-gram order.
 * @return the maximum (weighted) count of each n-gram across the sequences
 */
static public <TK> Counter<Sequence<TK>> getMaxNGramCounts(
    List<Sequence<TK>> sequences, double[] seqWeights, int maxOrder) {
  Counter<Sequence<TK>> maxCounts = new ClassicCounter<Sequence<TK>>();
  maxCounts.setDefaultReturnValue(0.0);
  if(seqWeights != null && seqWeights.length != sequences.size()) {
    throw new RuntimeException("Improper weight vector for sequences.");
  }
  
  int seqId = 0;
  for (Sequence<TK> sequence : sequences) {
    Counter<Sequence<TK>> counts = getNGramCounts(sequence, maxOrder);
    for (Sequence<TK> ngram : counts.keySet()) {
      double weight = seqWeights == null ? 1.0 : seqWeights[seqId];
      double countValue = weight * counts.getCount(ngram);
      double currentMax = maxCounts.getCount(ngram);
      maxCounts.setCount(ngram, Math.max(countValue, currentMax));
    }
    ++seqId;
  }
  return maxCounts;
}
 
Example 3
Source File: AdaGradFOBOSUpdater.java    From phrasal with GNU General Public License v3.0
public void updateL1(Counter<String> weights,
     Counter<String> gradient, int timeStep) {
  // w_{t+1} := w_t - nu*g_t
  for (String feature : gradient.keySet()) {
    double gValue = gradient.getCount(feature);
    double sgsValue = sumGradSquare.incrementCount(feature, gValue*gValue);
    double wValue = weights.getCount(feature);
    double currentrate = rate / (Math.sqrt(sgsValue)+eps);
    double testupdate = wValue - (currentrate * gValue);
    double realupdate = Math.signum(testupdate) * pospart( Math.abs(testupdate) - currentrate*this.lambda );
    if (realupdate == 0.0) {
      // Filter zeros
      weights.remove(feature);
    } else {
      weights.setCount(feature, realupdate);
    }
  }
}
 
Example 4
Source File: MERT.java    From phrasal with GNU General Public License v3.0
static public double l1norm(Counter<String> wts) {
  double sum = 0;
  for (String f : wts.keySet()) {
    sum += Math.abs(wts.getCount(f));
  }

  return sum;
}
 
Example 5
Source File: ComputeBitextIDF.java    From phrasal with GNU General Public License v3.0
/**
 * @param args
 */
public static void main(String[] args) {
  if (args.length > 0) {
    System.err.printf("Usage: java %s < files > idf-file%n", ComputeBitextIDF.class.getName());
    System.exit(-1);
  }

  Counter<String> documentsPerTerm = new ClassicCounter<String>(1000000);
  LineNumberReader reader = new LineNumberReader(new InputStreamReader(System.in));
  double nDocuments = 0.0;
  try {
    for (String line; (line = reader.readLine()) != null;) {
      String[] tokens = line.trim().split("\\s+");
      Set<String> seen = new HashSet<String>(tokens.length);
      for (String token : tokens) {
        if ( ! seen.contains(token)) {
          seen.add(token);
          documentsPerTerm.incrementCount(token);
        }
      }
    }
    nDocuments = reader.getLineNumber();
    reader.close();
  } catch (IOException e) {
    e.printStackTrace();
  }

  // Output the idfs
  System.err.printf("Bitext contains %d sentences and %d word types%n", (int) nDocuments, documentsPerTerm.keySet().size());
  for (String wordType : documentsPerTerm.keySet()) {
    double count = documentsPerTerm.getCount(wordType);
    System.out.printf("%s\t%f%n", wordType, Math.log(nDocuments / count));
  }
  System.out.printf("%s\t%f%n", UNK_TOKEN, Math.log(nDocuments / 1.0));
}
 
Example 6
Source File: NISTMetric.java    From phrasal with GNU General Public License v3.0
private double[] localMatchCounts(Counter<Sequence<TK>> clippedCounts) {
  double[] counts = new double[order];
  for (Sequence<TK> ngram : clippedCounts.keySet()) {
    double cnt = clippedCounts.getCount(ngram);
    if (cnt > 0) {
      int len = ngram.size();
      if (ngramInfo.containsKey(ngram))
        counts[len - 1] += cnt * ngramInfo.getCount(ngram);
      else
        System.err.println("Missing key for " + ngram.toString());
    }
  }
  return counts;
}
 
Example 7
Source File: NISTMetric.java    From phrasal with GNU General Public License v3.0
private void initReferences(List<List<Sequence<TK>>> referencesList) {
  int listSz = referencesList.size();
  for (int listI = 0; listI < listSz; listI++) {
    List<Sequence<TK>> references = referencesList.get(listI);

    int refsSz = references.size();
    if (refsSz == 0) {
      throw new RuntimeException(String.format(
          "No references found for data point: %d\n", listI));
    }

    refLengths[listI] = new int[refsSz];
    Counter<Sequence<TK>> maxReferenceCount = MetricUtils.getMaxNGramCounts(
        references, order);
    maxReferenceCounts.add(maxReferenceCount);
    refLengths[listI][0] = references.get(0).size();

    for (int refI = 1; refI < refsSz; refI++) {
      refLengths[listI][refI] = references.get(refI).size();
      Counter<Sequence<TK>> altCounts = MetricUtils.getNGramCounts(
          references.get(refI), order);
      for (Sequence<TK> sequence : new HashSet<Sequence<TK>>(
          altCounts.keySet())) {
        double cnt = maxReferenceCount.getCount(sequence);
        double altCnt = altCounts.getCount(sequence);
        if (cnt < altCnt) {
          maxReferenceCount.setCount(sequence, altCnt);
        }
      }
    }
  }
}
 
Example 8
Source File: BLEUMetric.java    From phrasal with GNU General Public License v3.0
private static <TK> double[] localMatchCounts(Counter<Sequence<TK>> clippedCounts, int order) {
  double[] counts = new double[order];
  for (Sequence<TK> ngram : clippedCounts.keySet()) {
    double cnt = clippedCounts.getCount(ngram);
    if (cnt > 0.0) {
      int len = ngram.size();
      counts[len - 1] += cnt;
    }
  }

  return counts;
}
 
Example 9
Source File: MERT.java    From phrasal with GNU General Public License v3.0
static public double wtSsd(Counter<String> oldWts, Counter<String> newWts) {
  double ssd = 0;
  for (String k : newWts.keySet()) {
    double diff = oldWts.getCount(k) - newWts.getCount(k);
    ssd += diff * diff;
  }
  return ssd;
}
 
Example 10
Source File: OptimizerUtils.java    From phrasal with GNU General Public License v3.0
/**
 * Add a scaled (positive) random vector to a weights vector.
 * 
 * @param wts - The weights vector to modify in place.
 * @param scale - The maximum magnitude of each random increment.
 */
public static void randomizeWeightsInPlace(Counter<String> wts, double scale) {
  for (String feature : wts.keySet()) {
    double epsilon = Math.random() * scale;
    double newValue = wts.getCount(feature) + epsilon;
    wts.setCount(feature, newValue);
  }
}
 
Example 11
Source File: PerceptronOptimizer.java    From phrasal with GNU General Public License v3.0
@Override
public Counter<String> optimize(Counter<String> initialWts) {

  List<ScoredFeaturizedTranslation<IString, String>> target = (new HillClimbingMultiTranslationMetricMax<IString, String>(
      emetric)).maximize(nbest);
  Counter<String> targetFeatures = MERT.summarizedAllFeaturesVector(target);
  Counter<String> wts = initialWts;

  while (true) {
    Scorer<String> scorer = new DenseScorer(wts, MERT.featureIndex);
    MultiTranslationMetricMax<IString, String> oneBestSearch = new HillClimbingMultiTranslationMetricMax<IString, String>(
        new ScorerWrapperEvaluationMetric<IString, String>(scorer));
    List<ScoredFeaturizedTranslation<IString, String>> oneBest = oneBestSearch
        .maximize(nbest);
    Counter<String> dir = MERT.summarizedAllFeaturesVector(oneBest);
    Counters.multiplyInPlace(dir, -1.0);
    dir.addAll(targetFeatures);
    Counter<String> newWts = mert.lineSearch(nbest, wts, dir, emetric);
    double ssd = 0;
    for (String k : newWts.keySet()) {
      double diff = wts.getCount(k) - newWts.getCount(k);
      ssd += diff * diff;
    }
    wts = newWts;
    if (ssd < MERT.NO_PROGRESS_SSD)
      break;
  }
  return wts;
}
 
Example 12
Source File: PairwiseRankingOptimizer.java    From phrasal with GNU General Public License v3.0
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  Counter<String> wts = new ClassicCounter<String>(initialWts);
  Counters.normalize(wts);
  double seedSeed = Math.abs(Counters.max(wts));
  long seed = (long)Math.exp(Math.log(seedSeed) + Math.log(Long.MAX_VALUE));
  System.err.printf("PRO thread using random seed: %d\n", seed);
  RVFDataset<String, String> proSamples = getSamples(new Random(seed));
  LogPrior lprior = new LogPrior();
  lprior.setSigma(l2sigma);
  LogisticClassifierFactory<String,String> lcf = new LogisticClassifierFactory<String,String>();
  LogisticClassifier<String, String> lc = lcf.trainClassifier(proSamples, lprior, false);
  Counter<String> decoderWeights = new ClassicCounter<String>(); 
  Counter<String> lcWeights = lc.weightsAsCounter();
  for (String key : lcWeights.keySet()) {
    double mul;
    if (key.startsWith("1 / ")) {
      mul = 1.0;
    } else if (key.startsWith("0 / ")) {
      mul = -1.0;
    } else {
      throw new RuntimeException("Unparsable weight name produced by logistic classifier: "+key);
    }
    String decoderKey = key.replaceFirst("^[10] / ", "");
    decoderWeights.incrementCount(decoderKey, mul*lcWeights.getCount(key));
  }

  synchronized (MERT.bestWts) {
    if (!updatedBestOnce) {
      System.err.println("Force updating weights (once)");
      double metricEval = MERT.evalAtPoint(nbest, decoderWeights, emetric);
      MERT.updateBest(decoderWeights, metricEval, true);
      updatedBestOnce = true;
    }
  }
  return decoderWeights;
}
 
Example 13
Source File: PowellOptimizer.java    From phrasal with GNU General Public License v3.0
@SuppressWarnings({ "rawtypes", "unchecked" })
@Override
public Counter<String> optimize(Counter<String> initialWts) {

  Counter<String> wts = initialWts;

  // initialize search directions
  List<Counter<String>> dirs = new ArrayList<Counter<String>>(
      initialWts.size());
  List<String> featureNames = new ArrayList<String>(wts.keySet());
  Collections.sort(featureNames);
  for (String featureName : featureNames) {
    Counter<String> dir = new ClassicCounter<String>();
    dir.incrementCount(featureName);
    dirs.add(dir);
  }

  // main optimization loop
  Counter[] p = new ClassicCounter[dirs.size()];
  double objValue = MERT.evalAtPoint(nbest, wts, emetric); // obj value w/o smoothing
  for (int iter = 0;; iter++) {
    // search along each direction
    p[0] = mert.lineSearch(nbest, wts, dirs.get(0), emetric);
    double eval = MERT.evalAtPoint(nbest, p[0], emetric);
    double biggestWin = Math.max(0, eval - objValue);
    System.err.printf("initial totalWin: %e (%e-%e)\n", biggestWin, eval,
        objValue);
    System.err.printf("apply @ wts: %e\n",
        MERT.evalAtPoint(nbest, wts, emetric));
    System.err.printf("apply @ p[0]: %e\n",
        MERT.evalAtPoint(nbest, p[0], emetric));
    objValue = eval;
    int biggestWinId = 0;
    double totalWin = biggestWin;
    double initObjValue = objValue;
    for (int i = 1; i < p.length; i++) {
      p[i] = mert.lineSearch(nbest, (Counter<String>) p[i - 1], dirs.get(i),
          emetric);
      eval = MERT.evalAtPoint(nbest, p[i], emetric);
      if (Math.max(0, eval - objValue) > biggestWin) {
        biggestWin = eval - objValue;
        biggestWinId = i;
      }
      totalWin += Math.max(0, eval - objValue);
      System.err.printf("\t%d totalWin: %e(%e-%e)\n", i, totalWin, eval,
          objValue);
      objValue = eval;
    }

    System.err.printf("%d: totalWin %e biggestWin: %e objValue: %e\n", iter,
        totalWin, biggestWin, objValue);

    // construct combined direction
    Counter<String> combinedDir = new ClassicCounter<String>(wts);
    Counters.multiplyInPlace(combinedDir, -1.0);
    combinedDir.addAll(p[p.length - 1]);

    // check to see if we should replace the dominant 'win' direction
    // during the last iteration of search with the combined search direction
    Counter<String> testPoint = new ClassicCounter<String>(p[p.length - 1]);
    testPoint.addAll(combinedDir);
    double testPointEval = MERT.evalAtPoint(nbest, testPoint, emetric);
    double extrapolatedWin = testPointEval - objValue;
    System.err.printf("Test Point Eval: %e, extrapolated win: %e\n",
        testPointEval, extrapolatedWin);
    if (extrapolatedWin > 0
        && 2 * (2 * totalWin - extrapolatedWin)
            * Math.pow(totalWin - biggestWin, 2.0) < Math.pow(
            extrapolatedWin, 2.0) * biggestWin) {
      System.err.printf(
          "%d: updating direction %d with combined search dir\n", iter,
          biggestWinId);
      MERT.normalize(combinedDir);
      dirs.set(biggestWinId, combinedDir);
    }

    // Search along combined dir even if replacement didn't happen
    wts = mert.lineSearch(nbest, p[p.length - 1], combinedDir, emetric);
    eval = MERT.evalAtPoint(nbest, wts, emetric);
    System.err.printf(
        "%d: Objective after combined search (gain: %e prior:%e)\n", iter,
        eval - objValue, objValue);

    objValue = eval;

    double finalObjValue = objValue;
    System.err.printf("Actual win: %e (%e-%e)\n", finalObjValue
        - initObjValue, finalObjValue, initObjValue);
    if (Math.abs(initObjValue - finalObjValue) < MERT.MIN_OBJECTIVE_DIFF)
      break; // changed to prevent infinite loops
  }

  return wts;
}
 
Example 14
Source File: SoftmaxMaxMarginMarkovNetwork.java    From phrasal with GNU General Public License v3.0
@SuppressWarnings("unchecked")
public Counter<String> optimize(Counter<String> initialWts) {
  Counter<String> wts = new ClassicCounter<String>(initialWts);

  EvaluationMetric<IString, String> modelMetric = new LinearCombinationMetric<IString, String>(
      new double[] { 1.0 },
      new ScorerWrapperEvaluationMetric<IString, String>(new DenseScorer(
          initialWts)));

  List<ScoredFeaturizedTranslation<IString, String>> current = (new HillClimbingMultiTranslationMetricMax<IString, String>(
      modelMetric)).maximize(nbest);

  List<ScoredFeaturizedTranslation<IString, String>> target = (new HillClimbingMultiTranslationMetricMax<IString, String>(
      emetric)).maximize(nbest);

  System.err.println("Target model: " + modelMetric.score(target)
      + " metric: " + emetric.score(target));
  System.err.println("Current model: " + modelMetric.score(current)
      + " metric: " + emetric.score(current));

  // create a mapping between weight names and optimization
  // weight vector positions
  String[] weightNames = new String[wts.size()];
  double[] initialWtsArr = new double[wts.size()];

  int nameIdx = 0;
  for (String feature : wts.keySet()) {
    initialWtsArr[nameIdx] = wts.getCount(feature);
    weightNames[nameIdx++] = feature;
  }

  double[][] lossMatrix = OptimizerUtils.calcDeltaMetric(nbest, target,
      emetric);

  Minimizer<DiffFunction> qn = new QNMinimizer(15, true);
  SoftMaxMarginMarkovNetwork sm3n = new SoftMaxMarginMarkovNetwork(
      weightNames, target, lossMatrix);
  double initialValueAt = sm3n.valueAt(initialWtsArr);
  if (initialValueAt == Double.POSITIVE_INFINITY
      || initialValueAt != initialValueAt) {
    System.err
        .printf("Initial Objective is infinite/NaN - normalizing weight vector");
    double normTerm = Counters.L2Norm(wts);
    for (int i = 0; i < initialWtsArr.length; i++) {
      initialWtsArr[i] /= normTerm;
    }
  }
  double initialObjValue = sm3n.valueAt(initialWtsArr);
  double initalDNorm = OptimizerUtils.norm2DoubleArray(sm3n
      .derivativeAt(initialWtsArr));
  double initalXNorm = OptimizerUtils.norm2DoubleArray(initialWtsArr);

  System.err.println("Initial Objective value: " + initialObjValue);
  double newX[] = qn.minimize(sm3n, 1e-4, initialWtsArr); // new
                                                          // double[wts.size()]
  Counter<String> newWts = OptimizerUtils.getWeightCounterFromArray(
      weightNames, newX);
  double finalObjValue = sm3n.valueAt(newX);

  double objDiff = initialObjValue - finalObjValue;
  double finalDNorm = OptimizerUtils
      .norm2DoubleArray(sm3n.derivativeAt(newX));
  double finalXNorm = OptimizerUtils.norm2DoubleArray(newX);
  double metricEval = MERT.evalAtPoint(nbest, newWts, emetric);
  System.err.println(">>>[Converge Info] ObjInit(" + initialObjValue
      + ") - ObjFinal(" + finalObjValue + ") = ObjDiff(" + objDiff
      + ") L2DInit(" + initalDNorm + ") L2DFinal(" + finalDNorm
      + ") L2XInit(" + initalXNorm + ") L2XFinal(" + finalXNorm + ")");

  MERT.updateBest(newWts, metricEval, true);

  return newWts;
}
 
Example 15
Source File: SoftmaxMaxMarginSlackRescaling.java    From phrasal with GNU General Public License v3.0
@SuppressWarnings("unchecked")
public Counter<String> optimize(Counter<String> initialWts) {
  Counter<String> wts = new ClassicCounter<String>(initialWts);

  EvaluationMetric<IString, String> modelMetric = new LinearCombinationMetric<IString, String>(
      new double[] { 1.0 },
      new ScorerWrapperEvaluationMetric<IString, String>(new DenseScorer(
          initialWts)));

  List<ScoredFeaturizedTranslation<IString, String>> current = (new HillClimbingMultiTranslationMetricMax<IString, String>(
      modelMetric)).maximize(nbest);

  List<ScoredFeaturizedTranslation<IString, String>> target = (new HillClimbingMultiTranslationMetricMax<IString, String>(
      emetric)).maximize(nbest);

  System.err.println("Target model: " + modelMetric.score(target)
      + " metric: " + emetric.score(target));
  System.err.println("Current model: " + modelMetric.score(current)
      + " metric: " + emetric.score(current));

  // create a mapping between weight names and optimization
  // weight vector positions
  String[] weightNames = new String[wts.size()];
  double[] initialWtsArr = new double[wts.size()];

  int nameIdx = 0;
  for (String feature : wts.keySet()) {
    initialWtsArr[nameIdx] = wts.getCount(feature);
    weightNames[nameIdx++] = feature;
  }

  double[][] lossMatrix = OptimizerUtils.calcDeltaMetric(nbest, target,
      emetric);

  // scale local relative loss by dataset size x 100
  // loss is then on a per sentence
  // BLEU 0-100 scale rather than BLEU 0-1.0
  for (int i = 0; i < lossMatrix.length; i++) {
    for (int j = 0; j < lossMatrix[i].length; j++) {
      lossMatrix[i][j] *= lossMatrix.length * 100;
    }
  }

  double lossSum = 0, lossMax = Double.NEGATIVE_INFINITY;
  int lossCnt = 0;

  for (int i = 0; i < lossMatrix.length; i++) {
    for (int j = 0; j < lossMatrix[i].length; j++) {
      lossCnt++;
      lossSum += lossMatrix[i][j];
      if (lossMatrix[i][j] > lossMax)
        lossMax = lossMatrix[i][j];
    }
  }
  System.err.printf("Loss Avg: %e\n", lossSum / lossCnt);
  System.err.printf("Loss Max: %e\n", lossMax);

  Minimizer<DiffFunction> qn = new QNMinimizer(15, true);
  SoftMaxMarginSlackRescaling sm3n = new SoftMaxMarginSlackRescaling(
      weightNames, target, lossMatrix);
  double initialValueAt = sm3n.valueAt(initialWtsArr);
  if (initialValueAt == Double.POSITIVE_INFINITY
      || initialValueAt != initialValueAt) {
    System.err
        .printf("Initial Objective is infinite/NaN - normalizing weight vector");
    double normTerm = Counters.L2Norm(wts);
    for (int i = 0; i < initialWtsArr.length; i++) {
      initialWtsArr[i] /= normTerm;
    }
  }
  double initialObjValue = sm3n.valueAt(initialWtsArr);
  double initalDNorm = OptimizerUtils.norm2DoubleArray(sm3n
      .derivativeAt(initialWtsArr));
  double initalXNorm = OptimizerUtils.norm2DoubleArray(initialWtsArr);

  System.err.println("Initial Objective value: " + initialObjValue);
  double newX[] = qn.minimize(sm3n, 1e-4, initialWtsArr); // new
                                                          // double[wts.size()]
  Counter<String> newWts = OptimizerUtils.getWeightCounterFromArray(
      weightNames, newX);
  double finalObjValue = sm3n.valueAt(newX);

  double objDiff = initialObjValue - finalObjValue;
  double finalDNorm = OptimizerUtils
      .norm2DoubleArray(sm3n.derivativeAt(newX));
  double finalXNorm = OptimizerUtils.norm2DoubleArray(newX);
  double metricEval = MERT.evalAtPoint(nbest, newWts, emetric);
  System.err.println(">>>[Converge Info] ObjInit(" + initialObjValue
      + ") - ObjFinal(" + finalObjValue + ") = ObjDiff(" + objDiff
      + ") L2DInit(" + initalDNorm + ") L2DFinal(" + finalDNorm
      + ") L2XInit(" + initalXNorm + ") L2XFinal(" + finalXNorm + ")");

  MERT.updateBest(newWts, metricEval, true);

  return newWts;
}
 
Example 16
Source File: MERT.java    From phrasal with GNU General Public License v3.0
static Counter<String> removeWts(Counter<String> wts, Counter<String> fixedWts) {
  if (fixedWts != null)
    for (String s : fixedWts.keySet())
      wts.remove(s);
  return wts;
}
 
Example 17
Source File: DownhillSimplexOptimizer.java    From phrasal with GNU General Public License v3.0
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  Counter<String> wts = new ClassicCounter<String>(initialWts);
      
  // create a mapping between weight names and optimization
  // weight vector positions
  String[] weightNames = new String[initialWts.size()];
  double[] initialWtsArr = new double[initialWts.size()];

  int nameIdx = 0;
  for (String feature : wts.keySet()) {
    initialWtsArr[nameIdx] = wts.getCount(feature);
    weightNames[nameIdx++] = feature;
  }

  Minimizer<Function> dhsm = new DownhillSimplexMinimizer();

  MERTObjective mo = new MERTObjective(weightNames);
  
  double initialValueAt = mo.valueAt(initialWtsArr);
  if (initialValueAt == Double.POSITIVE_INFINITY
      || initialValueAt != initialValueAt) {
    System.err
        .printf("Initial Objective is infinite/NaN - normalizing weight vector");
    double normTerm = Counters.L2Norm(wts);
    for (int i = 0; i < initialWtsArr.length; i++) {
      initialWtsArr[i] /= normTerm;
    }
  }
  
  double initialObjValue = mo.valueAt(initialWtsArr);

  System.err.println("Initial Objective value: " + initialObjValue);
  double newX[] = dhsm.minimize(mo, 1e-6, initialWtsArr); // new
                                                         // double[wts.size()]

  Counter<String> newWts = new ClassicCounter<String>();
  for (int i = 0; i < weightNames.length; i++) {
    newWts.setCount(weightNames[i], newX[i]);
  }

  double finalObjValue = mo.valueAt(newX);
  
  System.err.println("Final Objective value: " + finalObjValue);
  double metricEval = MERT.evalAtPoint(nbest, newWts, emetric);
  MERT.updateBest(newWts, metricEval);
  return newWts;
}
 
Example 18
Source File: BasicPowellOptimizer.java    From phrasal with GNU General Public License v3.0
@SuppressWarnings({ "unchecked", "rawtypes" })
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  Counter<String> wts = initialWts;

  // initialize search directions
  List<Counter<String>> axisDirs = new ArrayList<Counter<String>>(
      initialWts.size());
  List<String> featureNames = new ArrayList<String>(wts.keySet());
  Collections.sort(featureNames);
  for (String featureName : featureNames) {
    Counter<String> dir = new ClassicCounter<String>();
    dir.incrementCount(featureName);
    axisDirs.add(dir);
  }

  // main optimization loop
  Counter[] p = new ClassicCounter[axisDirs.size()];
  double objValue = MERT.evalAtPoint(nbest, wts, emetric); // obj value w/o smoothing
  List<Counter<String>> dirs = null;
  for (int iter = 0;; iter++) {
    if (iter % p.length == 0) {
      // reset after N iterations to avoid linearly dependent search
      // directions
      System.err.printf("%d: Search direction reset\n", iter);
      dirs = new ArrayList<Counter<String>>(axisDirs);
    }
    // search along each direction
    assert (dirs != null);
    p[0] = mert.lineSearch(nbest, wts, dirs.get(0), emetric);
    for (int i = 1; i < p.length; i++) {
      p[i] = mert.lineSearch(nbest, (Counter<String>) p[i - 1], dirs.get(i),
          emetric);
      dirs.set(i - 1, dirs.get(i)); // shift search directions
    }

    double totalWin = MERT.evalAtPoint(nbest, p[p.length - 1], emetric)
        - objValue;
    System.err.printf("%d: totalWin: %e Objective: %e\n", iter, totalWin,
        objValue);
    if (Math.abs(totalWin) < MERT.MIN_OBJECTIVE_DIFF)
      break;

    // construct combined direction
    Counter<String> combinedDir = new ClassicCounter<String>(wts);
    Counters.multiplyInPlace(combinedDir, -1.0);
    combinedDir.addAll(p[p.length - 1]);

    dirs.set(p.length - 1, combinedDir);

    // search along combined direction
    wts = mert.lineSearch(nbest, (Counter<String>) p[p.length - 1],
        dirs.get(p.length - 1), emetric);
    objValue = MERT.evalAtPoint(nbest, wts, emetric);
    System.err.printf("%d: Objective after combined search %e\n", iter,
        objValue);
  }

  return wts;
}
 
Example 19
Source File: CRFPostprocessor.java    From phrasal with GNU General Public License v3.0
/**
 * Evaluate the postprocessor given an input file specified in the flags.
 * 
 * @param preProcessor - Preprocessor used to construct the document reader for the test file.
 * @param pwOut - Writer that receives the evaluation report.
 */
protected void evaluate(Preprocessor preProcessor, PrintWriter pwOut) {
  System.err.println("Starting evaluation...");
  DocumentReaderAndWriter<CoreLabel> docReader = new ProcessorTools.PostprocessorDocumentReaderAndWriter(preProcessor);
  ObjectBank<List<CoreLabel>> lines =
    classifier.makeObjectBankFromFile(flags.testFile, docReader);

  Counter<String> labelTotal = new ClassicCounter<String>();
  Counter<String> labelCorrect = new ClassicCounter<String>();
  int total = 0;
  int correct = 0;
  PrintWriter pw = new PrintWriter(IOTools.getWriterFromFile("apply.out"));
  for (List<CoreLabel> line : lines) {
    line = classifier.classify(line);
    pw.println(Sentence.listToString(ProcessorTools.toPostProcessedSequence(line)));
    total += line.size();
    for (CoreLabel label : line) {
      String hypothesis = label.get(CoreAnnotations.AnswerAnnotation.class);
      String reference = label.get(CoreAnnotations.GoldAnswerAnnotation.class);
      labelTotal.incrementCount(reference);
      if (hypothesis.equals(reference)) {
        correct++;
        labelCorrect.incrementCount(reference);
      }
    }
  }
  pw.close();

  double accuracy = ((double) correct) / ((double) total);
  accuracy *= 100.0;

  pwOut.println("EVALUATION RESULTS");
  pwOut.printf("#datums:\t%d%n", total);
  pwOut.printf("#correct:\t%d%n", correct);
  pwOut.printf("accuracy:\t%.2f%n", accuracy);
  pwOut.println("==================");

  // Output the per label accuracies
  pwOut.println("PER LABEL ACCURACIES");
  for (String refLabel : labelTotal.keySet()) {
    double nTotal = labelTotal.getCount(refLabel);
    double nCorrect = labelCorrect.getCount(refLabel);
    double acc = (nCorrect / nTotal) * 100.0;
    pwOut.printf(" %s\t%.2f%n", refLabel, acc);
  }
}
 
Example 20
Source File: RandomAltPairs.java    From phrasal with GNU General Public License v3.0
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  System.err.printf("RandomAltPairs forceBetter = %b\n", forceBetter);
  Counter<String> wts = initialWts;

  for (int noProgress = 0; noProgress < MERT.NO_PROGRESS_LIMIT;) {
    Counter<String> dir;
    List<ScoredFeaturizedTranslation<IString, String>> rTrans;
    Scorer<String> scorer = new DenseScorer(wts, MERT.featureIndex);

    dir = MERT.summarizedAllFeaturesVector(rTrans = (forceBetter ? mert
        .randomBetterTranslations(nbest, wts, emetric) : mert
        .randomTranslations(nbest)));
    Counter<String> newWts1 = mert.lineSearch(nbest, wts, dir, emetric); // search toward random better translation
          
    MultiTranslationMetricMax<IString, String> oneBestSearch = new HillClimbingMultiTranslationMetricMax<IString, String>(
        new ScorerWrapperEvaluationMetric<IString, String>(scorer));
    List<ScoredFeaturizedTranslation<IString, String>> oneBest = oneBestSearch
        .maximize(nbest);
    
    Counters.subtractInPlace(dir, wts);

    System.err.printf("Random alternate score: %.5f \n",
        emetric.score(rTrans));

    Counter<String> newWts = mert.lineSearch(nbest, newWts1, dir, emetric);
    double eval = MERT.evalAtPoint(nbest, newWts, emetric);

    double ssd = 0;
    for (String k : newWts.keySet()) {
      double diff = wts.getCount(k) - newWts.getCount(k);
      ssd += diff * diff;
    }
    System.err.printf("Eval: %.5f SSD: %e (no progress: %d)\n", eval, ssd,
        noProgress);
    wts = newWts;
    if (ssd < MERT.NO_PROGRESS_SSD)
      noProgress++;
    else
      noProgress = 0;
  }
  return wts;
}