Java Code Examples for edu.stanford.nlp.stats.Counter#setCount()

The following examples show how to use edu.stanford.nlp.stats.Counter#setCount(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: AdaGradFOBOSUpdater.java (from phrasal, GNU General Public License v3.0) — 6 votes
/**
 * Applies one AdaGrad step followed by L1 (FOBOS-style) shrinkage to every
 * feature that appears in {@code gradient}.
 *
 * <p>For each feature, the squared gradient is added to {@code sumGradSquare}.
 * The per-feature rate is {@code rate / (sqrt(accumulated) + eps)}. The
 * tentative weight {@code w - featRate * g} is shrunk toward zero by
 * {@code featRate * lambda} via {@code pospart} (presumably max(x, 0) — confirm).
 * Features whose resulting weight is exactly zero are removed from
 * {@code weights}.
 *
 * @param weights  weight vector, updated in place
 * @param gradient gradient for this step; only its keys are visited
 * @param timeStep current time step (not used by this update)
 */
public void updateL1(Counter<String> weights,
     Counter<String> gradient, int timeStep) {
  for (String feat : gradient.keySet()) {
    final double g = gradient.getCount(feat);
    // Running sum of squared gradients for this feature, as returned by incrementCount.
    final double accumulated = sumGradSquare.incrementCount(feat, g * g);
    final double featRate = rate / (Math.sqrt(accumulated) + eps);
    final double tentative = weights.getCount(feat) - featRate * g;
    final double shrunk =
        Math.signum(tentative) * pospart(Math.abs(tentative) - featRate * this.lambda);
    if (shrunk != 0.0) {
      weights.setCount(feat, shrunk);
    } else {
      // Drop zeroed features so the weight vector stays sparse.
      weights.remove(feat);
    }
  }
}
 
Example 2
Source File: MetricUtils.java (from phrasal, GNU General Public License v3.0) — 6 votes
/**
 * Compute maximum (optionally weighted) n-gram counts over one or more sequences.
 *
 * <p>For every n-gram that occurs in any sequence, the result holds the
 * largest value of {@code seqWeights[i] * count_i(ngram)} over all sequences
 * {@code i}. N-grams that never occur report 0 (the counter's default return
 * value).
 *
 * @param sequences  the list of sequences
 * @param seqWeights per-sequence weights, parallel to {@code sequences}; may be
 *                   {@code null}, in which case every sequence has weight 1.0
 * @param maxOrder   the n-gram order passed to {@code getNGramCounts}
 * @return a counter mapping each n-gram to its maximum weighted count
 * @throws RuntimeException if {@code seqWeights} is non-null and its length
 *                          differs from {@code sequences.size()}
 */
static public <TK> Counter<Sequence<TK>> getMaxNGramCounts(
    List<Sequence<TK>> sequences, double[] seqWeights, int maxOrder) {
  if (seqWeights != null && seqWeights.length != sequences.size()) {
    throw new RuntimeException("Improper weight vector for sequences.");
  }
  Counter<Sequence<TK>> maxCounts = new ClassicCounter<Sequence<TK>>();
  maxCounts.setDefaultReturnValue(0.0);

  int seqId = 0;
  for (Sequence<TK> sequence : sequences) {
    // The weight depends only on the sequence, so look it up once per
    // sequence rather than once per n-gram.
    final double weight = seqWeights == null ? 1.0 : seqWeights[seqId];
    Counter<Sequence<TK>> counts = getNGramCounts(sequence, maxOrder);
    for (Sequence<TK> ngram : counts.keySet()) {
      double countValue = weight * counts.getCount(ngram);
      double currentMax = maxCounts.getCount(ngram);
      maxCounts.setCount(ngram, Math.max(countValue, currentMax));
    }
    ++seqId;
  }
  return maxCounts;
}
 
Example 3
Source File: MetricUtils.java (from phrasal, GNU General Public License v3.0) — 6 votes
/**
 * Calculates the "informativeness" of each n-gram, as used by the NIST metric.
 * The informativeness of the n-gram w_1:n is defined as
 * -log2(count(w_1:n) / count(w_1:n-1)). For unigrams the denominator is
 * {@code totWords}.
 *
 * @param ngramCounts n-gram counts according to the references
 * @param totWords    total number of words, used as the denominator when
 *                    computing the informativeness of unigrams
 * @return a new counter mapping each key of {@code ngramCounts} to its
 *         informativeness
 */
static public <TK> Counter<Sequence<TK>> getNGramInfo(
    Counter<Sequence<TK>> ngramCounts, int totWords) {
  Counter<Sequence<TK>> info = new ClassicCounter<Sequence<TK>>();
  for (Sequence<TK> gram : ngramCounts.keySet()) {
    final int len = gram.size();
    // Unigrams are normalized by the total word count; longer n-grams by the
    // count of their (n-1)-word prefix.
    final double denominator = (len > 1)
        ? ngramCounts.getCount(gram.subsequence(0, len - 1))
        : totWords;
    final double ratio = ngramCounts.getCount(gram) / denominator;
    info.setCount(gram, -Math.log(ratio) / LOG2);
  }
  return info;
}
 
Example 4
Source File: OnlineTuner.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Loads the initial weight vector from {@code wtsInitialFile}, optionally
 * overwriting it with heuristic start values and/or perturbing it.
 *
 * @param wtsInitialFile        file read with {@code IOTools.readWeights}; if that
 *                              returns {@code null}, an empty counter is used instead
 * @param uniformStartWeights   if true, every loaded feature and every baseline feature
 *                              of {@code translationModel} is assigned a fixed start
 *                              value (Moses heuristic): 0.5 for language-model features,
 *                              -1.0 for word-penalty features, 0.2 for everything else
 * @param randomizeStartWeights if true, the weights are passed to
 *                              {@code OptimizerUtils.randomizeWeightsInPlace} with
 *                              scale 1e-4
 * @param translationModel      source of the baseline feature names used when
 *                              {@code uniformStartWeights} is set
 * @return the (possibly adjusted) weight vector
 */
private static Counter<String> loadWeights(String wtsInitialFile,
    boolean uniformStartWeights, boolean randomizeStartWeights, TranslationModel<IString, String> translationModel) {
  final Counter<String> loaded = IOTools.readWeights(wtsInitialFile);
  final Counter<String> weights = (loaded != null) ? loaded : new ClassicCounter<String>();

  if (uniformStartWeights) {
    // Snapshot the key set first: setCount below mutates weights.
    final Set<String> features = new HashSet<>(weights.keySet());
    features.addAll(FeatureUtils.getBaselineFeatures(translationModel));
    for (String feature : features) {
      final double start;
      if (feature.startsWith(NGramLanguageModelFeaturizer.DEFAULT_FEATURE_NAME)) {
        start = 0.5;
      } else if (feature.startsWith(WordPenaltyFeaturizer.FEATURE_NAME)) {
        start = -1.0;
      } else {
        start = 0.2;
      }
      weights.setCount(feature, start);
    }
  }

  if (randomizeStartWeights) {
    // Small perturbation of the starting point.
    OptimizerUtils.randomizeWeightsInPlace(weights, 1e-4);
  }
  return weights;
}
 
Example 5
Source File: DownhillSimplexOptimizer.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Converts an optimizer point back into a named weight vector by pairing
 * {@code x[i]} with the field {@code weightNames[i]}.
 *
 * @param x point whose entries are indexed like {@code weightNames}
 * @return a new counter mapping each weight name to its value in {@code x}
 */
private Counter<String> vectorToWeights(double[] x) {
  final Counter<String> result = new ClassicCounter<String>();
  int idx = 0;
  for (String name : weightNames) {
    result.setCount(name, x[idx++]);
  }
  return result;
}
 
Example 6
Source File: OptimizerUtils.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Builds a weight counter from parallel arrays of feature names and values.
 *
 * @param weightNames feature names, one per weight
 * @param wtsArr      weight values, indexed like {@code weightNames}
 * @return a new counter mapping {@code weightNames[i]} to {@code wtsArr[i]}
 */
public static Counter<String> getWeightCounterFromArray(String[] weightNames,
    double[] wtsArr) {
  final Counter<String> counter = new ClassicCounter<String>();
  int i = 0;
  while (i < weightNames.length) {
    counter.setCount(weightNames[i], wtsArr[i]);
    i++;
  }
  return counter;
}
 
Example 7
Source File: OptimizerUtils.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Adds independent random noise {@code Math.random() * scale} to every
 * existing weight, in place. For a non-negative {@code scale} the noise lies in
 * {@code [0, scale)}.
 *
 * @param wts   weight vector to perturb; only keys already present are touched
 * @param scale multiplier applied to each uniform draw from {@code Math.random()}
 */
public static void randomizeWeightsInPlace(Counter<String> wts, double scale) {
  for (String key : wts.keySet()) {
    final double current = wts.getCount(key);
    wts.setCount(key, current + Math.random() * scale);
  }
}
 
Example 8
Source File: MERT.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Draws a random weight vector over {@code keySet}, taking each weight from
 * {@code globalRandom.nextDouble()}, and prints it to stderr.
 *
 * @param keySet feature names to assign weights to
 * @return a new counter with one random weight per feature
 */
static Counter<String> randomWts(Set<String> keySet) {
  final Counter<String> point = new ClassicCounter<String>();
  for (String feature : keySet) {
    final double draw = globalRandom.nextDouble();
    point.setCount(feature, draw);
  }
  System.err.printf("random Wts: %s%n", point);
  return point;
}
 
Example 9
Source File: NISTMetric.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Precomputes per-data-point reference statistics.
 *
 * <p>For each data point {@code listI} it:
 * <ul>
 *   <li>fills {@code refLengths[listI]} with the length of every reference;</li>
 *   <li>appends to {@code maxReferenceCounts} a counter that starts from
 *       {@code MetricUtils.getMaxNGramCounts(references, order)} and raises
 *       each n-gram's count to its count in any later reference.</li>
 * </ul>
 *
 * <p>NOTE(review): if the two-argument {@code getMaxNGramCounts} already takes
 * the maximum over all references, the {@code refI} loop only recomputes the
 * lengths and is otherwise redundant — confirm.
 *
 * @param referencesList one list of reference sequences per data point
 * @throws RuntimeException if any data point has no references
 */
private void initReferences(List<List<Sequence<TK>>> referencesList) {
  int listSz = referencesList.size();
  for (int listI = 0; listI < listSz; listI++) {
    List<Sequence<TK>> references = referencesList.get(listI);

    int refsSz = references.size();
    if (refsSz == 0) {
      throw new RuntimeException(String.format(
          "No references found for data point: %d\n", listI));
    }

    refLengths[listI] = new int[refsSz];
    Counter<Sequence<TK>> maxReferenceCount = MetricUtils.getMaxNGramCounts(
        references, order);
    maxReferenceCounts.add(maxReferenceCount);
    refLengths[listI][0] = references.get(0).size();

    for (int refI = 1; refI < refsSz; refI++) {
      refLengths[listI][refI] = references.get(refI).size();
      Counter<Sequence<TK>> altCounts = MetricUtils.getNGramCounts(
          references.get(refI), order);
      // Only maxReferenceCount is modified inside this loop, so altCounts'
      // key set can be iterated directly without a defensive copy.
      for (Sequence<TK> sequence : altCounts.keySet()) {
        double cnt = maxReferenceCount.getCount(sequence);
        double altCnt = altCounts.getCount(sequence);
        if (cnt < altCnt) {
          maxReferenceCount.setCount(sequence, altCnt);
        }
      }
    }
  }
}
 
Example 10
Source File: OverrideBinwts.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Command-line entry point: copies a weight file, overriding selected weights.
 *
 * <p>Arguments: {@code input overrides output}.
 * <ul>
 *   <li>Weights are read from {@code input} with {@code IOTools.readWeights}.</li>
 *   <li>Overrides are read from {@code overrides} with {@code IOTools.readWeightsPlain}.
 *       Each override sets the corresponding feature, and an override value of 0
 *       removes that feature instead.</li>
 *   <li>The result is written to {@code output}.</li>
 * </ul>
 * Exits with status -1 on a wrong argument count, when no weights could be
 * read from {@code input}, or on an I/O error while reading the overrides.
 */
public static void main(String[] args) {
  if (args.length != 3) {
    usage();
    System.exit(-1);
  }

  String input = args[0];
  String overrides = args[1];
  String output = args[2];

  System.err.println("reading weights from " + input);

  Counter<String> weights = IOTools.readWeights(input);
  // readWeights can return null; fail with a clear message instead of a
  // NullPointerException when the overrides are applied.
  if (weights == null) {
    System.err.println("could not read weights from " + input);
    System.exit(-1);
  }

  try {
    Counter<String> overridesW = IOTools.readWeightsPlain(overrides);
    System.err.println("read weights from  " + overrides + ":");
    for (Entry<String,Double> entry : overridesW.entrySet()) {
      // An override of 0 deletes the feature rather than storing an explicit zero.
      if (entry.getValue() == 0) {
        weights.remove(entry.getKey());
      } else {
        weights.setCount(entry.getKey(), entry.getValue());
      }
      System.err.println("setting feature: " + entry.getKey() + " = " + entry.getValue());
    }
  }
  catch (IOException e) {
    e.printStackTrace();
    System.exit(-1);
  }

  System.err.println("writing weights to " + output);

  IOTools.writeWeights(output, weights);
}
 
Example 11
Source File: IOTools.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Read weights from a plain text file.
 *
 * <p>Each line must contain exactly two fields separated by a single space:
 * a feature name and its weight (e.g. {@code LM 0.5}).
 *
 * @param filename path of the weight file
 * @return a counter mapping each feature name to its weight
 * @throws IOException if the file cannot be read, or if a line does not split
 *         into exactly two space-separated fields
 * @throws NumberFormatException if a weight field is not a parsable double
 */
public static Counter<String> readWeightsPlain(String filename) throws IOException {
  Counter<String> wts = new ClassicCounter<String>();
  // try-with-resources closes the reader on every exit path, including when
  // a malformed line or an unparsable number aborts the read part-way.
  try (LineNumberReader reader = new LineNumberReader(new FileReader(filename))) {
    for (String line; (line = reader.readLine()) != null;) {
      String[] input = line.split(" ");
      if (input.length != 2) {
        throw new IOException("Illegal input in weight file " + filename + ": " + line);
      }
      wts.setCount(input[0], Double.parseDouble(input[1]));
    }
  }
  return wts;
}
 
Example 12
Source File: TargetFunctionWordInsertion.java (from phrasal, GNU General Public License v3.0) — 5 votes
/**
 * Loads word counts from {@code filename} and returns the keys selected by
 * {@code Counters.topKeys(counter, rankCutoff)} (presumably the
 * {@code rankCutoff} highest-count word types).
 *
 * <p>Each line is expected to hold a word type and its count separated by
 * whitespace; other lines are reported on stderr and skipped. Word types that
 * are numeric/punctuation/symbols, or equal to the start/end tokens, are
 * ignored. The selected words are echoed to stderr.
 *
 * @param filename counts file, opened with {@code IOTools.getReaderFromFile}
 * @return the set of selected word types
 * @throws RuntimeException wrapping any {@link IOException} raised while reading
 */
private Set<IString> loadCountsFile(String filename) {
  Counter<IString> counter = new ClassicCounter<IString>();
  // try-with-resources guarantees the reader is closed even if reading or
  // number parsing fails part-way through the file.
  try (LineNumberReader reader = IOTools.getReaderFromFile(filename)) {
    for (String line; (line = reader.readLine()) != null;) {
      String[] fields = line.trim().split("\\s+");
      if (fields.length == 2) {
        String wordType = fields[0];
        if ( ! (TokenUtils.isNumericOrPunctuationOrSymbols(wordType) ||
                wordType.equals(TokenUtils.START_TOKEN.toString()) ||
                wordType.equals(TokenUtils.END_TOKEN.toString()))) {
          counter.setCount(new IString(wordType), Double.parseDouble(fields[1]));
        }
      } else {
        System.err.printf("%s: Discarding line %s%n", this.getClass().getName(), line);
      }
    }
  } catch (IOException e) {
    throw new RuntimeException(e);
  }

  Set<IString> set = new HashSet<>(Counters.topKeys(counter, rankCutoff));
  for (IString word : set) {
    System.err.printf(" %s%n", word);
  }
  return set;
}
 
Example 13
Source File: DownhillSimplexOptimizer.java (from phrasal, GNU General Public License v3.0) — 4 votes
/**
 * Runs a downhill-simplex minimization of the MERT objective, starting from
 * {@code initialWts}.
 *
 * <p>If the objective at the starting point is +infinity or NaN, the start
 * vector is divided by its L2 norm, when that norm is positive. The optimized
 * weights are then evaluated with {@code MERT.evalAtPoint} and reported to
 * {@code MERT.updateBest}.
 *
 * @param initialWts starting weights; only read (a copy is used internally)
 * @return the optimized weights, keyed by the same feature names
 */
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  Counter<String> wts = new ClassicCounter<String>(initialWts);

  // Fix an ordering of the weight names so the optimizer can work on a plain
  // double[]: weightNames[i] names the value stored at initialWtsArr[i].
  String[] weightNames = new String[initialWts.size()];
  double[] initialWtsArr = new double[initialWts.size()];

  int nameIdx = 0;
  for (String feature : wts.keySet()) {
    initialWtsArr[nameIdx] = wts.getCount(feature);
    weightNames[nameIdx++] = feature;
  }

  Minimizer<Function> dhsm = new DownhillSimplexMinimizer();

  MERTObjective mo = new MERTObjective(weightNames);

  double initialValueAt = mo.valueAt(initialWtsArr);
  if (initialValueAt == Double.POSITIVE_INFINITY || Double.isNaN(initialValueAt)) {
    System.err.printf("Initial Objective is infinite/NaN - normalizing weight vector%n");
    double normTerm = Counters.L2Norm(wts);
    // Dividing by a zero norm would turn every start weight into NaN.
    if (normTerm > 0.0) {
      for (int i = 0; i < initialWtsArr.length; i++) {
        initialWtsArr[i] /= normTerm;
      }
    }
  }

  double initialObjValue = mo.valueAt(initialWtsArr);

  System.err.println("Initial Objective value: " + initialObjValue);
  double[] newX = dhsm.minimize(mo, 1e-6, initialWtsArr);

  Counter<String> newWts = new ClassicCounter<String>();
  for (int i = 0; i < weightNames.length; i++) {
    newWts.setCount(weightNames[i], newX[i]);
  }

  double finalObjValue = mo.valueAt(newX);

  System.err.println("Final Objective value: " + finalObjValue);
  double metricEval = MERT.evalAtPoint(nbest, newWts, emetric);
  MERT.updateBest(newWts, metricEval);
  return newWts;
}