Java Code Examples for edu.stanford.nlp.stats.Counter

The following examples show how to use edu.stanford.nlp.stats.Counter. These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source Project: wiseowl   Source File: DocumentFrequencyCounter.java    License: MIT License 6 votes vote down vote up
/**
 * Build a noun-count map for a single document string.
 *
 * @param document raw document text
 * @return counter of noun tokens found in the document
 */
private static Counter<String> getIDFMapForDocument(String document) {
  // Strip Gigaword heading separators up front -- they slow processing
  // down and contribute nothing useful.
  document = headingSeparator.matcher(document).replaceAll("");

  DocumentPreprocessor preprocessor = new DocumentPreprocessor(new StringReader(document));
  preprocessor.setTokenizerFactory(tokenizerFactory);

  Counter<String> idfMap = new ClassicCounter<String>();
  for (List<HasWord> sentence : preprocessor) {
    // Overly long sentences are skipped entirely to bound tagging cost.
    if (sentence.size() > MAX_SENTENCE_LENGTH) {
      continue;
    }

    for (TaggedWord taggedWord : tagger.tagSentence(sentence)) {
      // NOTE(review): prefix test is lowercase "n" -- assumes the tagger
      // emits lowercase noun tags; confirm against the tagger model in use.
      if (taggedWord.tag().startsWith("n")) {
        idfMap.incrementCount(taggedWord.word());
      }
    }
  }

  return idfMap;
}
 
Example 2
Source Project: wiseowl   Source File: DocumentFrequencyCounter.java    License: MIT License 6 votes vote down vote up
/**
 * Get an IDF map for all the documents in the given XML file.
 *
 * @param file reader over an XML file whose document element contains
 *             TAG_DOCUMENT children, each with exactly one TAG_TEXT child
 * @return merged counts over all documents, plus a "__all__" entry counting
 *         the number of documents processed (presumably the IDF document
 *         total -- confirm against callers)
 * @throws SAXException if the XML cannot be parsed
 * @throws IOException if reading the input fails
 * @throws TransformerException propagated from text extraction
 */
private static Counter<String> getIDFMapForFile(Reader file)
  throws SAXException, IOException, TransformerException {

  DocumentBuilder parser = XMLUtils.getXmlParser();
  Document xml = parser.parse(new ReaderInputStream(file));
  NodeList docNodes = xml.getDocumentElement().getElementsByTagName(TAG_DOCUMENT);

  Element doc;
  Counter<String> idfMap = new ClassicCounter<String>();
  for (int i = 0; i < docNodes.getLength(); i++) {
    doc = (Element) docNodes.item(i);
    NodeList texts = doc.getElementsByTagName(TAG_TEXT);
    // Each document is expected to carry exactly one TAG_TEXT element;
    // only checked when assertions are enabled.
    assert texts.getLength() == 1;

    Element text = (Element) texts.item(0);
    String textContent = getFullTextContent(text);

    // Merge this document's per-document counts into the running total.
    idfMap.addAll(getIDFMapForDocument(textContent));

    // Increment magic counter
    idfMap.incrementCount("__all__");
  }

  return idfMap;
}
 
Example 3
/**
 * Featurize a KBP input, returning an empty counter for degenerate spans.
 *
 * @param input the relation-extraction input to featurize
 * @return feature counts for the input (empty if spans overlap or are empty)
 */
public static Counter<String> features(KBPInput input) {
  // Make sure RegexNER tags are present on the sentence before featurizing.
  input.sentence.regexner(DefaultPaths.DEFAULT_KBP_REGEXNER_CASED, false);
  input.sentence.regexner(DefaultPaths.DEFAULT_KBP_REGEXNER_CASELESS, true);

  // Degenerate inputs -- overlapping or empty spans -- yield no features.
  boolean degenerate = Span.overlaps(input.subjectSpan, input.objectSpan)
      || input.subjectSpan.size() == 0
      || input.objectSpan.size() == 0;
  if (degenerate) {
    return new ClassicCounter<>();
  }

  // Run each featurizer family over the input, accumulating into one counter.
  ClassicCounter<String> featureCounts = new ClassicCounter<>();
  denseFeatures(input, input.sentence, featureCounts);
  surfaceFeatures(input, input.sentence, featureCounts);
  dependencyFeatures(input, input.sentence, featureCounts);
  relationSpecificFeatures(input, input.sentence, featureCounts);

  return featureCounts;
}
 
Example 4
/**
 * Score the given input, returning both the classification decision and the
 * probability of that decision.
 * Note that this method will not return a relation which does not type check.
 *
 *
 * @param input The input to classify.
 * @return A pair with the relation we classified into, along with its confidence.
 */
public Pair<String,Double> classify(KBPInput input) {
  RVFDatum<String, String> datum = new RVFDatum<>(features(input));
  // Turn raw classifier scores into a normalized probability distribution.
  Counter<String> scores =  classifier.scoresOf(datum);
  Counters.expInPlace(scores);
  Counters.normalize(scores);
  String best = Counters.argmax(scores);
  // While it doesn't type check, continue going down the list.
  // NO_RELATION is always an option somewhere in there, so safe to keep going...
  // NOTE(review): RelationType.fromString(best).get() assumes every
  // non-NO_RELATION label parses to a known RelationType; an unknown label
  // would make Optional.get() throw. Confirm label set matches RelationType.
  while (!NO_RELATION.equals(best) &&
      (!edu.stanford.nlp.ie.KBPRelationExtractor.RelationType.fromString(best).get().validNamedEntityLabels.contains(input.objectType) ||
       RelationType.fromString(best).get().entityType != input.subjectType) ) {
    // Drop the ill-typed candidate, renormalize, and take the next best.
    scores.remove(best);
    Counters.normalize(scores);
    best = Counters.argmax(scores);
  }
  return Pair.makePair(best, scores.getCount(best));
}
 
Example 5
Source Project: phrasal   Source File: MakeWordClasses.java    License: GNU General Public License v3.0 6 votes vote down vote up
/**
 * Merge a partial clustering update into this state's counters and apply
 * its word-to-class reassignments.
 *
 * @param result the delta counts and assignments to fold in
 * @return the number of words whose class assignment changed
 */
private int updateCountsWith(PartialStateUpdate result) {
  // Fold the delta class counts into the running totals.
  Counters.addInPlace(classCount, result.deltaClassCount);
  for (Integer classId : result.deltaClassHistoryCount.firstKeySet()) {
    Counters.addInPlace(this.classHistoryCount.getCounter(classId),
        result.deltaClassHistoryCount.getCounter(classId));
  }

  // Apply the proposed word->class assignments, counting how many changed.
  int numChanged = 0;
  for (Map.Entry<IString, Integer> entry : result.wordToClass.entrySet()) {
    IString word = entry.getKey();
    int proposedClass = entry.getValue();
    if (wordToClass.get(word) != proposedClass) {
      ++numChanged;
      wordToClass.put(word, proposedClass);
    }
  }
  return numChanged;
}
 
Example 6
Source Project: phrasal   Source File: OnlineTuner.java    License: GNU General Public License v3.0 6 votes vote down vote up
/**
 * Load additional feature values from plain text file.
 * Features are only updated if not already present in the weight vector,
 * so existing weights are never overwritten.
 *
 * @param additionalFeatureWeightsFile path to a plain-text weights file
 */
private void addAdditionalFeatureWeights(String additionalFeatureWeightsFile) {
  try {
    Counter<String> weights = IOTools.readWeightsPlain(additionalFeatureWeightsFile);
    System.err.println("read weights: ");
    for(Entry<String,Double> entry : weights.entrySet()) {
      if(!wtsAccumulator.containsKey(entry.getKey())) {
        wtsAccumulator.setCount(entry.getKey(), entry.getValue());
        System.err.println("setting feature: " + entry.getKey() + " = " + entry.getValue());
      }
      else System.err.println("skipping feature: " + entry.getKey());
    }
  }
  catch (IOException e) {
    // Attach the exception to the log record instead of dumping a separate
    // stack trace to stderr (log4j2 treats a trailing Throwable argument
    // as the log event's cause).
    logger.fatal("Could not load additional weights from : {}", additionalFeatureWeightsFile, e);
  }
}
 
Example 7
Source Project: phrasal   Source File: OptimizerUtils.java    License: GNU General Public License v3.0 6 votes vote down vote up
/**
 * Compute the set of features that occur in at least {@code minSegmentCount}
 * segments of the n-best list collection.
 *
 * @param nbest the n-best lists to scan
 * @param minSegmentCount minimum number of segments a feature must appear in
 * @return the surviving feature names
 */
public static Set<String> featureWhiteList(FlatNBestList nbest, int minSegmentCount) {
  // For each feature, count the number of segments in which it fires at
  // least once (set-per-segment de-duplicates within a segment).
  Counter<String> featureSegmentCounts = new ClassicCounter<String>();
  for (List<ScoredFeaturizedTranslation<IString, String>> nbestlist : nbest.nbestLists()) {
    Set<String> featuresInSegment = new HashSet<String>();
    for (ScoredFeaturizedTranslation<IString, String> translation : nbestlist) {
      for (FeatureValue<String> feature : translation.features) {
        featuresInSegment.add(feature.name);
      }
    }
    for (String featureName : featuresInSegment) {
      featureSegmentCounts.incrementCount(featureName);
    }
  }
  // The -1 mirrors the original threshold choice; verify against the
  // inclusivity of Counters.keysAbove before changing it.
  return Counters.keysAbove(featureSegmentCounts, minSegmentCount - 1);
}
 
Example 8
/**
 * Read all whitespace-delimited tokens from the reader and return the set
 * of the k most frequent ones.
 *
 * @param reader source of whitespace-separated tokens
 * @param k maximum number of distinct tokens to retain
 * @return the k (or fewer) most frequent tokens
 * @throws IOException if reading fails
 */
private static Set<String> getMostFrequentTokens(LineNumberReader reader, int k) throws IOException {

  Counter<String> tokenCounts = new ClassicCounter<String>();

  String line;
  while ((line = reader.readLine()) != null) {
    // Idiomatic array declaration (was C-style "String tokens[]").
    String[] tokens = line.split("\\s+");
    for (String token : tokens) {
      tokenCounts.incrementCount(token);
    }
  }

  // Keep only the top-k entries, then snapshot the surviving keys.
  // (The original also nulled the local counter afterwards, which is a
  // no-op for garbage collection at the end of a method; removed.)
  Counters.retainTop(tokenCounts, k);
  return new HashSet<>(tokenCounts.keySet());
}
 
Example 9
Source Project: phrasal   Source File: MetricUtils.java    License: GNU General Public License v3.0 6 votes vote down vote up
/**
 * Calculates the "informativeness" of each ngram, which is used by the NIST
 * metric. In Matlab notation, the informativeness of the ngram w_1:n is
 * defined as -log2(count(w_1:n)/count(w_1:n-1)).
 *
 * @param ngramCounts
 *          ngram counts according to references
 * @param totWords
 *          total number of words, which is used to compute the
 *          informativeness of unigrams.
 * @return counter mapping each ngram to its information value
 */
static public <TK> Counter<Sequence<TK>> getNGramInfo(
    Counter<Sequence<TK>> ngramCounts, int totWords) {
  Counter<Sequence<TK>> ngramInfo = new ClassicCounter<Sequence<TK>>();

  for (Sequence<TK> ngram : ngramCounts.keySet()) {
    double count = ngramCounts.getCount(ngram);
    // Unigrams are normalized by the corpus size; longer ngrams by the
    // count of their (n-1)-gram prefix.
    double denominator = (ngram.size() > 1)
        ? ngramCounts.getCount(ngram.subsequence(0, ngram.size() - 1))
        : totWords;
    ngramInfo.setCount(ngram, -Math.log(count / denominator) / LOG2);
  }
  return ngramInfo;
}
 
Example 10
Source Project: phrasal   Source File: ScorerFactory.java    License: GNU General Public License v3.0 6 votes vote down vote up
/**
 * Creates a scorer for the given scorer name.
 *
 * @param scorerName one of UNIFORM_SCORER, DENSE_SCORER, SPARSE_SCORER
 * @param config weight configuration for dense/sparse scorers
 * @param featureIndex feature index for dense/sparse scorers
 * @return the constructed scorer
 * @throws IOException if scorer construction fails
 * @throws RuntimeException if the scorer name is not recognized
 */
public static Scorer<String> factory(String scorerName, Counter<String> config, Index<String> featureIndex)
    throws IOException {

  switch (scorerName) {
    case UNIFORM_SCORER:
      return new UniformScorer<String>();
    case DENSE_SCORER:
      return new DenseScorer(config, featureIndex);
    case SPARSE_SCORER:
      return new SparseScorer(config, featureIndex);
    default:
      // Unreachable for known names; fail loudly on anything else.
      throw new RuntimeException(String.format("Unknown scorer \"%s\"",
          scorerName));
  }
}
 
Example 11
/**
 * True online learning, one example at a time.
 *
 * @param weights current weight vector (must be non-null)
 * @param source source sentence for this example
 * @param sourceId non-negative id of the source sentence
 * @param translations non-empty n-best list for the source
 * @param references non-empty reference translations
 * @param referenceWeights weights over the references (unused here directly)
 * @param scoreMetric sentence-level metric used for sampling (non-null)
 * @return the gradient computed from the sampled n-best pairs
 */
@Override
public Counter<String> getGradient(Counter<String> weights, Sequence<IString> source, int sourceId,
    List<RichTranslation<IString, String>> translations, List<Sequence<IString>> references,
    double[] referenceWeights, SentenceLevelMetric<IString, String> scoreMetric) {
  Objects.requireNonNull(weights);
  Objects.requireNonNull(scoreMetric);
  assert sourceId >= 0;
  assert translations.size() > 0 : "No translations for source id: " + String.valueOf(sourceId);
  assert references.size() > 0;

  // Sample from the n-best list
  List<Datum> dataset = sampleNbestList(sourceId, source, scoreMetric, translations, references);
  // An empty sample yields a zero/empty gradient; we still compute and
  // return it, only logging a warning.
  Counter<String> gradient = computeGradient(dataset, weights, 1);
  if (dataset.isEmpty()) {
    logger.warn("Null gradient for sourceId: {}", sourceId);
  }
  
  if (VERBOSE) {
     System.err.printf("True online gradient");
     displayGradient(gradient);
  }
 
  return gradient;
}
 
Example 12
Source Project: phrasal   Source File: ConvertWeights.java    License: GNU General Public License v3.0 6 votes vote down vote up
/**
 * Convert a legacy serialized weights file to the new format, preserving
 * the original under a ".old" suffix.
 *
 * Usage: java ConvertWeights old_wts
 */
@SuppressWarnings("unchecked")
public static void main(String[] args) {
  if (args.length != 1) {
    System.err.printf("Usage: java %s old_wts%n", ConvertWeights.class.getName());
    System.exit(-1);
  }
  final String filename = args[0];

  // Read the weights in the legacy serialized format.
  Counter<String> oldWeights = IOTools.deserialize(filename, ClassicCounter.class,
      SerializationMode.DEFAULT);

  // Move the original out of the way before rewriting in place.
  Path oldFilename = Paths.get(filename + ".old");
  try {
    Files.move(Paths.get(filename), oldFilename);
  } catch (IOException e) {
    e.printStackTrace();
    System.exit(-1);
  }

  IOTools.writeWeights(filename, oldWeights);
  System.out.printf("Converted %s to new format (old file moved to %s)%n",
      filename, oldFilename.toString());
}
 
Example 13
Source Project: wiseowl   Source File: Summarizer.java    License: MIT License 5 votes vote down vote up
/**
 * Tally how often each token occurs across the given sentences.
 *
 * @param sentences annotated sentences to scan
 * @return counter mapping token text to its occurrence count
 */
private static Counter<String> getTermFrequencies(List<CoreMap> sentences) {
  Counter<String> termFrequencies = new ClassicCounter<String>();
  for (CoreMap sentence : sentences) {
    for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) {
      termFrequencies.incrementCount(token.get(CoreAnnotations.TextAnnotation.class));
    }
  }
  return termFrequencies;
}
 
Example 14
Source Project: wiseowl   Source File: Summarizer.java    License: MIT License 5 votes vote down vote up
/**
 * Produce an extractive summary of the document from its top-ranked
 * sentences.
 *
 * @param document raw document text
 * @param numSentences maximum number of sentences to include
 * @return the selected sentences concatenated, each followed by a space
 */
public String summarize(String document, int numSentences) {
  Annotation annotation = pipeline.process(document);
  List<CoreMap> sentences = annotation.get(CoreAnnotations.SentencesAnnotation.class);

  Counter<String> tfs = getTermFrequencies(sentences);
  sentences = rankSentences(sentences, tfs);

  // Bound the loop by the actual sentence count: the original indexed
  // straight to numSentences and threw IndexOutOfBoundsException on short
  // documents.
  int limit = Math.min(numSentences, sentences.size());
  StringBuilder ret = new StringBuilder();
  for (int i = 0; i < limit; i++) {
    ret.append(sentences.get(i));
    ret.append(" ");
  }

  return ret.toString();
}
 
Example 15
Source Project: phrasal   Source File: NISTMetric.java    License: GNU General Public License v3.0 5 votes vote down vote up
/**
 * Accumulate information-weighted match mass per ngram order.
 *
 * @param clippedCounts clipped candidate ngram counts
 * @return array where index n-1 holds the weighted match total for
 *         ngrams of length n
 */
private double[] localMatchCounts(Counter<Sequence<TK>> clippedCounts) {
  double[] counts = new double[order];
  for (Sequence<TK> ngram : clippedCounts.keySet()) {
    double clipped = clippedCounts.getCount(ngram);
    if (clipped <= 0) {
      continue;
    }
    if (ngramInfo.containsKey(ngram)) {
      // Weight the clipped count by the ngram's informativeness.
      counts[ngram.size() - 1] += clipped * ngramInfo.getCount(ngram);
    } else {
      System.err.println("Missing key for " + ngram.toString());
    }
  }
  return counts;
}
 
Example 16
Source Project: wiseowl   Source File: DocumentFrequencyCounter.java    License: MIT License 5 votes vote down vote up
/**
 * Build IDF counters for each input file in parallel and serialize the
 * merged counter to OUT_FILE.
 *
 * @param args paths of the files to process, one task per file
 */
public static void main(String[] args) throws InterruptedException, ExecutionException,
  IOException {
  ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
  List<Future<Counter<String>>> futures = new ArrayList<Future<Counter<String>>>();

  for (String filePath : args)
    futures.add(pool.submit(new FileIDFBuilder(new File(filePath))));

  int finished = 0;
  Counter<String> overall = new ClassicCounter<String>();

  // Merge results in submission order as each task completes.
  for (Future<Counter<String>> future : futures) {
    System.err.printf("%s: Polling future #%d / %d%n",
        dateFormat.format(new Date()), finished + 1, args.length);
    Counter<String> result = future.get();
    finished++;
    System.err.printf("%s: Finished future #%d / %d%n",
        dateFormat.format(new Date()), finished, args.length);

    System.err.printf("\tMerging counter.. ");
    overall.addAll(result);
    System.err.printf("done.%n");
  }
  pool.shutdown();

  System.err.printf("\n%s: Saving to '%s'.. ", dateFormat.format(new Date()),
      OUT_FILE);
  // try-with-resources: the original never closed the stream, leaking the
  // file handle and risking unflushed buffered bytes at JVM exit.
  try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(OUT_FILE))) {
    oos.writeObject(overall);
  }
  System.err.printf("done.%n");
}
 
Example 17
Source Project: phrasal   Source File: BLEUMetric.java    License: GNU General Public License v3.0 5 votes vote down vote up
/**
 * Add the next candidate translation to this incremental BLEU metric.
 * Candidates are added in segment order (position = number added so far);
 * a null translation records an empty slot for its segment.
 *
 * @param nbestId index of the candidate within its n-best list
 * @param translation candidate translation, or null to skip this segment
 * @return this metric, for chaining
 * @throws RuntimeException if more candidates than references are added
 */
@Override
public IncrementalEvaluationMetric<TK, FV> add(int nbestId,
    Sequence<TK> translation) {
  // The segment position is implied by how many candidates we've seen.
  int pos = sequences.size();
  if (pos >= maxReferenceCounts.size()) {
    throw new RuntimeException(String.format(
        "Attempt to add more candidates, %d, than references, %d.",
        pos + 1, maxReferenceCounts.size()));
  }

  if (smooth) {
    if (translation != null) {
      sequences.add(translation);
      // Smoothed BLEU accumulates per-sentence scores instead of counts.
      smoothSum += getLocalSmoothScore(translation, pos, nbestId);
      smoothCnt++;
    } else {
      sequences.add(null);
    }
  } else {
    if (translation != null) {
      // Corpus BLEU: clip the candidate's ngram counts against the
      // per-segment maximum reference counts, then accumulate.
      Counter<Sequence<TK>> candidateCounts = MetricUtils.getNGramCounts(
          translation, order);
      MetricUtils.clipCounts(candidateCounts, maxReferenceCounts.get(pos));
      sequences.add(translation);
      incCounts(candidateCounts, translation);
      // c/r: running candidate and (best-match) reference length totals.
      c += translation.size();
      r += bestMatchLength(refLengths[pos], translation.size());
    } else {
      sequences.add(null);
    }
  }
  return this;
}
 
Example 18
/**
 * Train the KBP statistical extractor: read the training examples,
 * featurize them in parallel into an RVFDataset, train a multinomial
 * classifier, and serialize the wrapped extractor to MODEL_FILE.
 *
 * @throws IOException if reading the dataset or writing the model fails
 */
public static void trainModel() throws IOException {
    forceTrack("Training data");
    List<Pair<KBPInput, String>> trainExamples = DatasetUtils.readDataset(TRAIN_FILE);
    log.info("Read " + trainExamples.size() + " examples");
    log.info("" + trainExamples.stream().map(Pair::second).filter(NO_RELATION::equals).count() + " are " + NO_RELATION);
    endTrack("Training data");

    // Featurize + create the dataset
    forceTrack("Creating dataset");
    RVFDataset<String, String> dataset = new RVFDataset<>();
    final AtomicInteger i = new AtomicInteger(0);
    long beginTime = System.currentTimeMillis();
    // Featurization is the expensive step, so it runs on a parallel stream;
    // only the dataset insertion is synchronized.
    trainExamples.stream().parallel().forEach(example -> {
        if (i.incrementAndGet() % 1000 == 0) {
            log.info("[" + Redwood.formatTimeDifference(System.currentTimeMillis() - beginTime) +
                    "] Featurized " + i.get() + " / " + trainExamples.size() + " examples");
        }
        Counter<String> features = features(example.first);  // This takes a while per example
        synchronized (dataset) {
            dataset.add(new RVFDatum<>(features, example.second));
        }
    });
    trainExamples.clear();  // Free up some memory
    endTrack("Creating dataset");

    // Train the classifier
    log.info("Training classifier:");
    Classifier<String, String> classifier = trainMultinomialClassifier(dataset, FEATURE_THRESHOLD, SIGMA);
    dataset.clear();  // Free up some memory

    // Save the classifier
    IOUtils.writeObjectToFile(new IntelKBPStatisticalExtractor(classifier), MODEL_FILE);
}
 
Example 19
Source Project: phrasal   Source File: NISTMetric.java    License: GNU General Public License v3.0 5 votes vote down vote up
/**
 * Initialize per-segment reference statistics: record each reference's
 * length and build, for every segment, the element-wise maximum of ngram
 * counts across all of that segment's references.
 *
 * @param referencesList one list of reference translations per segment
 * @throws RuntimeException if any segment has no references
 */
private void initReferences(List<List<Sequence<TK>>> referencesList) {
  int listSz = referencesList.size();
  for (int listI = 0; listI < listSz; listI++) {
    List<Sequence<TK>> references = referencesList.get(listI);

    int refsSz = references.size();
    if (refsSz == 0) {
      throw new RuntimeException(String.format(
          "No references found for data point: %d\n", listI));
    }

    refLengths[listI] = new int[refsSz];
    // Seed the max-count table from the references (getMaxNGramCounts);
    // subsequent references are merged in below.
    Counter<Sequence<TK>> maxReferenceCount = MetricUtils.getMaxNGramCounts(
        references, order);
    maxReferenceCounts.add(maxReferenceCount);
    refLengths[listI][0] = references.get(0).size();

    for (int refI = 1; refI < refsSz; refI++) {
      refLengths[listI][refI] = references.get(refI).size();
      Counter<Sequence<TK>> altCounts = MetricUtils.getNGramCounts(
          references.get(refI), order);
      // Copy the key set before mutating the counter we read from.
      for (Sequence<TK> sequence : new HashSet<Sequence<TK>>(
          altCounts.keySet())) {
        double cnt = maxReferenceCount.getCount(sequence);
        double altCnt = altCounts.getCount(sequence);
        if (cnt < altCnt) {
          maxReferenceCount.setCount(sequence, altCnt);
        }
      }
    }
  }
}
 
Example 20
Source Project: phrasal   Source File: ClustererState.java    License: GNU General Public License v3.0 5 votes vote down vote up
/**
 * Snapshot of the clusterer's state; all arguments are stored directly
 * (no defensive copies).
 *
 * @param vocabularySubset the words covered by this state
 * @param wordCount counts per word
 * @param historyCount ngram-history counts per word
 * @param inWordToClass current word-to-class assignment
 * @param inClassCount counts per class
 * @param inClassHistoryCount ngram-history counts per class
 * @param numClasses total number of classes
 * @param currentObjectiveValue objective value at this state
 */
public ClustererState(List<IString> vocabularySubset, Counter<IString> wordCount,
    TwoDimensionalCounter<IString, NgramHistory> historyCount, Map<IString, Integer> inWordToClass,
    Counter<Integer> inClassCount,
    TwoDimensionalCounter<Integer, NgramHistory> inClassHistoryCount, int numClasses, 
    double currentObjectiveValue) {
  this.vocabularySubset = vocabularySubset;
  this.wordCount = wordCount;
  this.historyCount = historyCount;
  this.wordToClass = inWordToClass;
  this.classCount = inClassCount;
  this.classHistoryCount = inClassHistoryCount;
  this.numClasses = numClasses;
  this.currentObjectiveValue = currentObjectiveValue;
}
 
Example 21
/**
 * Load a counts file ("word count" per line) and return the rankCutoff
 * highest-count word types, excluding numeric/punctuation/symbol tokens
 * and the sentence boundary tokens.
 *
 * @param filename path of the counts file
 * @return the retained word types
 * @throws RuntimeException wrapping any IOException raised while reading
 */
private Set<IString> loadCountsFile(String filename) {
  Counter<IString> counter = new ClassicCounter<IString>();
  // try-with-resources: the original only closed the reader on the success
  // path, leaking the file handle whenever an IOException was thrown
  // mid-read.
  try (LineNumberReader reader = IOTools.getReaderFromFile(filename)) {
    for (String line; (line = reader.readLine()) != null;) {
      String[] fields = line.trim().split("\\s+");
      if (fields.length == 2) {
        String wordType = fields[0];
        if ( ! (TokenUtils.isNumericOrPunctuationOrSymbols(wordType) ||
                wordType.equals(TokenUtils.START_TOKEN.toString()) ||
                wordType.equals(TokenUtils.END_TOKEN.toString()))) {
          // parseDouble avoids the needless Double boxing of valueOf.
          counter.setCount(new IString(wordType), Double.parseDouble(fields[1]));
        }
      } else {
        System.err.printf("%s: Discarding line %s%n", this.getClass().getName(), line);
      }
    }
    Set<IString> set = new HashSet<>(Counters.topKeys(counter, rankCutoff));
    for (IString word : set) {
      System.err.printf(" %s%n", word);
    }
    return set;

  } catch (IOException e) {
    throw new RuntimeException(e);
  }
}
 
Example 22
Source Project: phrasal   Source File: OnlineTuner.java    License: GNU General Public License v3.0 5 votes vote down vote up
/**
 * Carrier for the result of processing one tuning input; all arguments
 * are stored directly without copying.
 *
 * @param gradient gradient computed for this input
 * @param inputId identifier of the processed input
 * @param nbestLists n-best translation lists, parallel to translationIds
 * @param translationIds ids parallel to nbestLists
 * @param wordAlignments word alignments for the input
 * @param prefixDecodingOutput translations from prefix-constrained decoding
 */
public ProcessorOutput(Counter<String> gradient, 
    int inputId, 
    List<List<RichTranslation<IString, String>>> nbestLists, int[] translationIds, List<SymmetricalWordAlignment> wordAlignments,
    List<RichTranslation<IString, String>> prefixDecodingOutput) {
  this.gradient = gradient;
  this.inputId = inputId;
  this.nbestLists = nbestLists;
  this.translationIds = translationIds;
  this.wordAlignments = wordAlignments;
  this.prefixDecodingOutput = prefixDecodingOutput;
}
 
Example 23
Source Project: phrasal   Source File: MetricUtils.java    License: GNU General Public License v3.0 5 votes vote down vote up
/**
 * Extract every ngram of length 1..maxOrder from the sequence, with counts.
 *
 * @param <TK> token type
 * @param sequence the token sequence to scan
 * @param maxOrder maximum ngram length to extract
 * @return counter mapping each ngram subsequence to its frequency
 */
static public <TK> Counter<Sequence<TK>> getNGramCounts(Sequence<TK> sequence, int maxOrder) {
  Counter<Sequence<TK>> counts = new ClassicCounter<>();
  final int length = sequence.size();
  for (int start = 0; start < length; start++) {
    // End index is exclusive; cap it so ngrams never run past the sequence.
    int lastEnd = Math.min(length, start + maxOrder);
    for (int end = start + 1; end <= lastEnd; end++) {
      counts.incrementCount(sequence.subsequence(start, end));
    }
  }
  return counts;
}
 
Example 24
Source Project: phrasal   Source File: MERT.java    License: GNU General Public License v3.0 5 votes vote down vote up
/**
 * Print up to the 100 largest-magnitude weights, one "name value" pair per
 * line, in descending order of magnitude.
 *
 * @param wts the weight vector to display
 */
static void displayWeights(Counter<String> wts) {
  List<Pair<String, Double>> sorted = Counters.toDescendingMagnitudeSortedListWithCounts(wts);
  // Truncate long weight vectors to keep the output readable.
  List<Pair<String, Double>> shown = (sorted.size() > 100) ? sorted.subList(0, 100) : sorted;
  for (Pair<String, Double> entry : shown) {
    System.out.printf("%s %g\n", entry.first, entry.second);
  }
}
 
Example 25
/**
 * One PRO optimization step: sample pairwise examples from the n-best
 * list, train a binary logistic classifier on them, and convert the
 * learned classifier weights back into decoder feature weights.
 *
 * @param initialWts starting weight vector (copied, not modified)
 * @return decoder weights derived from the logistic classifier
 */
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  Counter<String> wts = new ClassicCounter<String>(initialWts);
  Counters.normalize(wts);
  // Derive a weight-dependent RNG seed from the largest normalized weight
  // magnitude so distinct starting points sample differently.
  double seedSeed = Math.abs(Counters.max(wts));
  long seed = (long)Math.exp(Math.log(seedSeed) + Math.log(Long.MAX_VALUE));
  System.err.printf("PRO thread using random seed: %d\n", seed);
  RVFDataset<String, String> proSamples = getSamples(new Random(seed));
  LogPrior lprior = new LogPrior();
  lprior.setSigma(l2sigma);
  LogisticClassifierFactory<String,String> lcf = new LogisticClassifierFactory<String,String>();
  LogisticClassifier<String, String> lc = lcf.trainClassifier(proSamples, lprior, false);
  Counter<String> decoderWeights = new ClassicCounter<String>(); 
  Counter<String> lcWeights = lc.weightsAsCounter();
  // Classifier weight names look like "1 / <feature>" or "0 / <feature>";
  // the class prefix determines the sign of the decoder weight.
  for (String key : lcWeights.keySet()) {
    double mul;
    if (key.startsWith("1 / ")) {
      mul = 1.0;
    } else if (key.startsWith("0 / ")) {
      mul = -1.0;
    } else {
      throw new RuntimeException("Unparsable weight name produced by logistic classifier: "+key);
    }
    String decoderKey = key.replaceFirst("^[10] / ", "");
    decoderWeights.incrementCount(decoderKey, mul*lcWeights.getCount(key));
  }

  // Ensure the shared best-weights record is seeded at least once,
  // guarded by the MERT.bestWts lock.
  synchronized (MERT.bestWts) {
    if (!updatedBestOnce) {
      System.err.println("Force updating weights (once)");
      double metricEval = MERT.evalAtPoint(nbest, decoderWeights, emetric);
      MERT.updateBest(decoderWeights, metricEval, true);
      updatedBestOnce = true;
    }
  }
  return decoderWeights;
}
 
Example 26
/**
 * Constructor.
 * 
 * @param initialRate initial AdaGrad learning rate
 * @param expectedNumFeatures expected feature count, used to presize counters
 * @param L1lambda default L1 regularization strength
 * @param customL1 per-feature L1 strengths (presumably override L1lambda
 *                 during updates -- confirm in the update method)
 * @param fixedFeatures features excluded from weight updates
 */
public AdaGradFastFOBOSUpdater(double initialRate, int expectedNumFeatures, double L1lambda, 
    Counter<String> customL1, Set<String> fixedFeatures) {
  this.rate = initialRate;
  this.L1lambda = L1lambda;
  // Presize accumulators to avoid rehashing as features stream in.
  sumGradSquare = new ClassicCounter<>(expectedNumFeatures);
  lastUpdated = new ClassicCounter<>(expectedNumFeatures);
  this.customL1 = customL1;
  this.fixedFeatures = fixedFeatures;
}
 
Example 27
/**
 * Snapshot of AdaGrad-FOBOS updater state.
 *
 * @param h accumulated gradient history per feature (stored as gradHistory)
 * @param r custom per-feature regularization strengths (stored as customReg)
 * @param f names of features held fixed during updates
 * @param u per-feature record of the last update (stored as lastUp)
 * @param t current time step
 */
public AdaGradFastFOBOSState(Counter<String> h, Counter<String> r, Set<String> f, Counter<String> u, int t) {
  this.gradHistory = h;
  this.customReg = r;
  this.fixedFeatures = f;
  this.lastUp = u;
  this.timeStep = t;
}
 
Example 28
Source Project: phrasal   Source File: SequenceOptimizer.java    License: GNU General Public License v3.0 5 votes vote down vote up
/**
 * Run each batch optimizer in sequence, iterating each one until the
 * objective change falls below MIN_OBJECTIVE_CHANGE, the objective gets
 * worse, or looping is disabled. Weights that make the objective worse
 * are discarded rather than adopted.
 *
 * @param initialWts starting weight vector
 * @return the best weights found by the optimizer sequence
 */
@Override
public Counter<String> optimize(Counter<String> initialWts) {
  Counter<String> wts = initialWts;
  for (BatchOptimizer opt : opts) {

    boolean done = false;

    while (!done) {
      Counter<String> newWts = opt.optimize(wts);

      // Sum of squared differences, used only for progress reporting.
      double wtSsd = MERT.wtSsd(newWts, wts);

      double oldE = MERT.evalAtPoint(nbest, wts, emetric);
      double newE = MERT.evalAtPoint(nbest, newWts, emetric);
      // MERT.updateBest(newWts, -newE);

      // Stop on convergence, regression, or when looping is disabled.
      boolean worse = oldE > newE;
      done = Math.abs(oldE - newE) <= MIN_OBJECTIVE_CHANGE || !loop || worse;

      System.err.printf(
          "seq optimizer: %s -> %s (%s) ssd: %f done: %s opt: %s\n", oldE,
          newE, newE - oldE, wtSsd, done, opt.toString());

      // Keep the old weights when the step made things worse.
      if (worse)
        System.err.printf("WARNING: negative objective change!");
      else
        wts = newWts;
    }
  }
  return wts;
}
 
Example 29
/**
 * Convert a dense weight vector into a counter keyed by feature name,
 * pairing x[i] with weightNames[i].
 *
 * @param x dense weights, parallel to weightNames
 * @return counter of feature name to weight
 */
private Counter<String> vectorToWeights(double[] x) {
  Counter<String> weights = new ClassicCounter<String>();
  final int n = weightNames.length;
  for (int idx = 0; idx < n; ++idx) {
    weights.setCount(weightNames[idx], x[idx]);
  }
  return weights;
}
 
Example 30
/**
 * Mini-batch gradients are not supported by this optimizer.
 *
 * @throws UnsupportedOperationException always; 1-best MIRA is online-only
 */
@Override
public Counter<String> getBatchGradient(Counter<String> weights,
    List<Sequence<IString>> sources, int[] sourceIds,
    List<List<RichTranslation<IString, String>>> translations,
    List<List<Sequence<IString>>> references,
    double[] referenceWeights, SentenceLevelMetric<IString, String> scoreMetric) {
  throw new UnsupportedOperationException("1-best MIRA does not support mini-batch learning");
}