/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.joshua.decoder.ff.lm.bloomfilter_lm;

import java.io.Externalizable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.ff.lm.DefaultNGramLanguageModel;
import org.apache.joshua.util.Regex;
import org.apache.joshua.util.io.LineReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An n-gram language model with linearly-interpolated Witten-Bell smoothing, using a Bloom filter
 * as its main data structure. A Bloom filter is a lossy data structure that can be used to test
 * for set membership.
 */
public class BloomFilterLanguageModel extends DefaultNGramLanguageModel implements Externalizable {
  /**
   * An initial value used for hashing n-grams so that they can be stored in a Bloom filter.
   */
  public static final int HASH_SEED = 17;

  /**
   * Another value used in the process of hashing n-grams.
   */
  public static final int HASH_OFFSET = 37;

  /**
   * The maximum score that a language model feature function can return to the Joshua decoder.
   */
  public static final double MAX_SCORE = 100.0;

  /**
   * The logger for this class.
   */
  private static final Logger LOG = LoggerFactory.getLogger(BloomFilterLanguageModel.class);

  /**
   * The Bloom filter data structure itself.
   */
  private BloomFilter bf;

  /**
   * The base of the logarithm used to quantize n-gram counts. N-gram counts are quantized
   * logarithmically to reduce the number of times we need to query the Bloom filter.
   */
  private double quantizationBase;

  /**
   * Natural log of the number of tokens seen in the training corpus.
   */
  private double numTokens;

  /**
   * An array of long pairs, used as hash functions for storing or retrieving the count of an
   * n-gram in the Bloom filter.
   */
  private long[][] countFuncs;

  /**
   * An array of long pairs, used as hash functions for storing or retrieving the number of
   * distinct types observed after an n-gram.
   */
  private long[][] typesFuncs;

  /**
   * The smoothed probability of an unseen n-gram. This is also the probability of any n-gram
   * under the zeroth-order model.
   */
  private transient double p0;

  /**
   * The interpolation constant between Witten-Bell models of order zero and one. Stored in a
   * field because it can be calculated ahead of time; it doesn't depend on the particular n-gram.
   */
  private transient double lambda0;

  /**
   * The maximum possible quantized count of any n-gram stored in the Bloom filter. Used as an
   * upper bound on the count that could be returned when querying the Bloom filter.
   */
  private transient int maxQ; // max quantized count
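  // Illustrative lifecycle sketch, not part of the original class; the file names and
  // concrete argument values are hypothetical. The LM is built offline once via main()
  // (whose size argument is multiplied by 2^23, i.e. it is given in MB), then loaded at
  // decoding time through the two-argument constructor below:
  //
  //   // offline: stats file, order 3, 128 MB filter, quantization base 2.0
  //   //   java org.apache.joshua.decoder.ff.lm.bloomfilter_lm.BloomFilterLanguageModel \
  //   //       counts.txt.gz 3 128 2.0 lm.bf.gz
  //   // at decoding time:
  //   BloomFilterLanguageModel lm = new BloomFilterLanguageModel(3, "lm.bf.gz");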
  /**
   * Constructor called from the Joshua decoder. This constructor assumes that the LM has already
   * been built, and takes the name of the file where the LM is stored.
   *
   * @param order the order of the language model
   * @param filename path to the file where the language model is stored
   * @throws IOException if the Bloom filter language model cannot be rebuilt from the input file
   */
  public BloomFilterLanguageModel(int order, String filename) throws IOException {
    super(order);
    try {
      readExternal(new ObjectInputStream(new GZIPInputStream(new FileInputStream(filename))));
    } catch (ClassNotFoundException e) {
      IOException ioe = new IOException("Could not rebuild bloom filter LM from file " + filename);
      ioe.initCause(e);
      throw ioe;
    }

    int vocabSize = Vocabulary.size();
    p0 = -Math.log(vocabSize + 1);
    double oneMinusLambda0 = numTokens - logAdd(Math.log(vocabSize), numTokens);
    p0 += oneMinusLambda0;
    lambda0 = Math.log(vocabSize) - logAdd(Math.log(vocabSize), numTokens);
    maxQ = quantize((long) Math.exp(numTokens));
  }

  /**
   * Constructor to be used by the main function. This constructor is used to build a new language
   * model from scratch. An LM should be built with the main function before using it in the
   * Joshua decoder.
   *
   * @param filename path to the file of training corpus statistics
   * @param order the order of the language model
   * @param size the size of the Bloom filter, in bits
   * @param base the base of the logarithm used for quantization
   */
  private BloomFilterLanguageModel(String filename, int order, int size, double base) {
    super(order);
    quantizationBase = base;
    populateBloomFilter(size, filename);
  }

  /**
   * Calculates the linearly-interpolated Witten-Bell probability for a given ngram. This is
   * calculated as:
   *
   * <pre>
   * p(w|h) = L(h) * pML(w|h) + (1 - L(h)) * p(w|h')
   * </pre>
   *
   * where w is a word, h is a history, h' is the history h with the first word removed, pML is
   * the maximum-likelihood estimate of the probability, and L(.) is lambda, the interpolation
   * factor, which depends only on the history h:
   *
   * <pre>
   * L(h) = s(h) / (s(h) + c(h))
   * </pre>
   *
   * where s(.) is the observed number of distinct types after h, and c(.) is the observed count
   * of h in the training corpus.
   * <p>
   * In fact this model calculates the probability starting from the lowest order and working its
   * way up, to take advantage of the one-sided error rate inherent in using a Bloom filter data
   * structure.
   *
   * @param ngram the ngram whose probability is to be calculated
   * @param ngramOrder the order of the ngram
   *
   * @return the linearly-interpolated Witten-Bell smoothed probability of an ngram
   */
  private float wittenBell(int[] ngram, int ngramOrder) {
    int end = ngram.length;
    double p = p0; // current calculated probability
    // note that p0 and lambda0 are independent of the given
    // ngram, so they are calculated ahead of time.
    int MAX_QCOUNT = getCount(ngram, ngram.length - 1, ngram.length, maxQ);
    if (MAX_QCOUNT == 0) // OOV!
      return (float) p;
    double pML = Math.log(unQuantize(MAX_QCOUNT)) - numTokens;

    // p += lambda0 * pML (in log space)
    p = logAdd(p, (lambda0 + pML));
    if (ngram.length == 1) { // if it's a unigram, we're done
      return (float) p;
    }
    // otherwise we calculate the linear interpolation
    // with higher order models.
    for (int i = end - 2; i >= end - ngramOrder && i >= 0; i--) {
      int historyCnt = getCount(ngram, i, end, MAX_QCOUNT);
      // if the count for the history is zero, all higher
      // terms in the interpolation must be zero, so we
      // are done here.
      if (historyCnt == 0) {
        return (float) p;
      }
      int historyTypesAfter = getTypesAfter(ngram, i, end, historyCnt);
      // unquantize the counts we got from the BF
      double HC = unQuantize(historyCnt);
      double HTA = 1 + unQuantize(historyTypesAfter);
      // interpolation constant
      double lambda = Math.log(HTA) - Math.log(HTA + HC);
      double oneMinusLambda = Math.log(HC) - Math.log(HTA + HC);
      // p *= 1 - lambda (in log space)
      p += oneMinusLambda;
      int wordCount = getCount(ngram, i + 1, end, historyTypesAfter);
      double WC = unQuantize(wordCount);
      // p += lambda * p_ML(w|h) (in log space)
      if (WC == 0)
        return (float) p;
      p = logAdd(p, lambda + Math.log(WC) - Math.log(HC));
      MAX_QCOUNT = wordCount;
    }
    return (float) p;
  }
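  // Worked example of one interpolation step in the loop above (the numbers are made
  // up for illustration): suppose the unquantized history count is HC = 8, the
  // adjusted types-after count is HTA = 3, and the word count is WC = 4. Then
  //   lambda         = log(3 / (3 + 8))
  //   oneMinusLambda = log(8 / (3 + 8))
  // and the update p = logAdd(p + oneMinusLambda, lambda + log(4) - log(8))
  // is the log-space form of p = (1 - L(h)) * p + L(h) * (4 / 8).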
  /**
   * Retrieves the count of an ngram from the Bloom filter. That is, how many times did we see
   * this ngram in the training corpus? This corresponds roughly to algorithm 2 in Talbot and
   * Osborne's "Tera-Scale LMs on the Cheap."
   *
   * @param ngram array containing the ngram as a sub-array
   * @param start the index of the first word of the ngram
   * @param end the index after the last word of the ngram
   * @param qcount the maximum possible count to be returned
   *
   * @return the number of times the ngram was seen in the training corpus, quantized
   */
  private int getCount(int[] ngram, int start, int end, int qcount) {
    for (int i = 1; i <= qcount; i++) {
      int hash = hashNgram(ngram, start, end, i);
      if (!bf.query(hash, countFuncs)) {
        return i - 1;
      }
    }
    return qcount;
  }

  /**
   * Retrieves the number of distinct types that follow an ngram in the training corpus.
   * <p>
   * This is another version of algorithm 2. As noted in the paper, we have different algorithms
   * for getting ngram counts versus suffix counts because c(x) = 1 is a proxy item for s(x) = 1.
   *
   * @param ngram an array that contains the ngram as a sub-array
   * @param start the index of the first word of the ngram
   * @param end the index after the last word of the ngram
   * @param qcount the maximum possible return value
   *
   * @return the number of distinct types observed to follow an ngram in the training corpus,
   *         quantized
   */
  private int getTypesAfter(int[] ngram, int start, int end, int qcount) {
    // first we check c(x) >= 1
    int hash = hashNgram(ngram, start, end, 1);
    if (!bf.query(hash, countFuncs)) {
      return 0;
    }
    // if c(x) >= 1, we check for the stored suffix count
    for (int i = 1; i < qcount; i++) {
      hash = hashNgram(ngram, start, end, i);
      if (!bf.query(hash, typesFuncs)) {
        return i - 1;
      }
    }
    return qcount;
  }

  /**
   * Logarithmically quantizes raw counts. The quantization scheme is described in Talbot and
   * Osborne's paper "Tera-Scale LMs on the Cheap."
   *
   * @param x long giving the raw count to be quantized
   *
   * @return the quantized count
   */
  private int quantize(long x) {
    return 1 + (int) Math.floor(Math.log(x) / Math.log(quantizationBase));
  }

  /**
   * Unquantizes a quantized count.
   *
   * @param x the quantized count
   *
   * @return the expected raw value of the quantized count
   */
  private double unQuantize(int x) {
    if (x == 0) {
      return 0;
    } else {
      return ((quantizationBase + 1) * Math.pow(quantizationBase, x - 1) - 1) / 2;
    }
  }
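  // Numerical sketch of the quantization round trip implemented by the two methods
  // above (the base is chosen for illustration): with quantizationBase = 2,
  //   quantize(1000) = 1 + floor(log(1000) / log(2)) = 10
  //   unQuantize(10) = ((2 + 1) * 2^9 - 1) / 2 = 767.5
  // i.e. every raw count in [512, 1023] falls into bucket 10, and unQuantize returns
  // the expected value of a count drawn uniformly from that bucket.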
  /**
   * Converts an n-gram and a count into a value that can be stored into a Bloom filter. This is
   * adapted directly from <code>AbstractPhrase.hashCode()</code> elsewhere in the Joshua code
   * base.
   *
   * @param ngram an array containing the ngram as a sub-array
   * @param start the index of the first word of the ngram
   * @param end the index after the last word of the ngram
   * @param val the count of the ngram
   *
   * @return a value suitable to be stored in a Bloom filter
   */
  private int hashNgram(int[] ngram, int start, int end, int val) {
    int result = HASH_OFFSET * HASH_SEED + val;
    for (int i = start; i < end; i++)
      result = HASH_OFFSET * result + ngram[i];
    return result;
  }

  /**
   * Adds two numbers that are in the log domain, avoiding underflow.
   *
   * @param x one summand
   * @param y the other summand
   *
   * @return the log of the sum of the exponents of the two numbers
   */
  private static double logAdd(double x, double y) {
    if (y <= x) {
      return x + Math.log1p(Math.exp(y - x));
    } else {
      return y + Math.log1p(Math.exp(x - y));
    }
  }

  /**
   * Builds a language model and stores it in a file.
   *
   * @param argv command-line arguments
   */
  public static void main(String[] argv) {
    if (argv.length < 5) {
      String msg = "usage: BloomFilterLanguageModel <statistics file> <order> <size>"
          + " <quantization base> <output file>";
      System.err.println(msg);
      LOG.error(msg);
      return;
    }
    int order = Integer.parseInt(argv[1]);
    int size = (int) (Integer.parseInt(argv[2]) * Math.pow(2, 23));
    double base = Double.parseDouble(argv[3]);

    try {
      BloomFilterLanguageModel lm = new BloomFilterLanguageModel(argv[0], order, size, base);
      try (ObjectOutputStream out =
          new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(argv[4])))) {
        lm.writeExternal(out);
      }
    } catch (IOException e) {
      LOG.error(e.getMessage(), e);
    }
  }

  /**
   * Adds ngram counts and counts of distinct types after ngrams, read from a file, to the Bloom
   * filter. Each line of the file gives an ngram followed by its count:
   *
   * <pre>
   * ngram1 count1
   * ngram2 count2
   * ...
   * </pre>
   *
   * The counts of distinct types observed after each history are accumulated while the ngram
   * counts are read.
   *
   * @param bloomFilterSize the size of the Bloom filter, in bits
   * @param filename path to the statistics file
   */
  private void populateBloomFilter(int bloomFilterSize, String filename) {
    HashMap<String, Long> typesAfter = new HashMap<>();
    try {
      FileInputStream file_in = new FileInputStream(filename);
      FileInputStream file_in_copy = new FileInputStream(filename);
      InputStream in;
      InputStream estimateStream;
      if (filename.endsWith(".gz")) {
        in = new GZIPInputStream(file_in);
        estimateStream = new GZIPInputStream(file_in_copy);
      } else {
        in = file_in;
        estimateStream = file_in_copy;
      }
      int numObjects = estimateNumberOfObjects(estimateStream);
      LOG.debug("Estimated number of objects: {}", numObjects);
      bf = new BloomFilter(bloomFilterSize, numObjects);
      countFuncs = bf.initializeHashFunctions();
      populateFromInputStream(in, typesAfter);
      in.close();
    } catch (IOException e) {
      LOG.error(e.getMessage(), e);
      return;
    }
    typesFuncs = bf.initializeHashFunctions();

    for (String history : typesAfter.keySet()) {
      String[] toks = Regex.spaces.split(history);
      int[] hist = new int[toks.length];
      for (int i = 0; i < toks.length; i++)
        hist[i] = Vocabulary.id(toks[i]);
      add(hist, typesAfter.get(history), typesFuncs);
    }
  }
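  // Small worked input for populateBloomFilter() above (the file contents are
  // hypothetical, in the "ngram count" format it expects):
  //
  //   the       100   -> count(the) stored; 100 tokens folded into numTokens (log-space)
  //   the cat     3   -> count(the cat) stored; typesAfter("the") += 1
  //   the dog     2   -> count(the dog) stored; typesAfter("the") += 1
  //
  // After the pass, typesAfter("the") == 2 is itself stored in the filter under the
  // types hash functions, supplying s(the) for the Witten-Bell interpolation weights.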
  /**
   * Estimates the number of objects that will be stored in the Bloom filter. The optimum number
   * of hash functions depends on the number of items that will be stored, so we want a guess
   * before we begin to read the statistics file and store it.
   *
   * @param source an InputStream pointing to the training corpus stats
   *
   * @return an estimate of the number of objects to be stored in the Bloom filter
   */
  private int estimateNumberOfObjects(InputStream source) {
    int numLines = 0;
    long maxCount = 0;
    for (String line : new LineReader(source)) {
      if (line.trim().equals(""))
        continue;
      String[] toks = Regex.spaces.split(line);
      if (toks.length > ngramOrder + 1)
        continue;
      try {
        long cnt = Long.parseLong(toks[toks.length - 1]);
        if (cnt > maxCount)
          maxCount = cnt;
      } catch (NumberFormatException e) {
        LOG.error(e.getMessage(), e);
        break;
      }
      numLines++;
    }
    double estimate = Math.log(maxCount) / Math.log(quantizationBase);
    return (int) Math.round(numLines * estimate);
  }

  /**
   * Reads the statistics from a source and stores them in the Bloom filter. The ngram counts are
   * stored immediately in the Bloom filter, but the counts of distinct types following each
   * ngram are accumulated from the file as we go.
   *
   * @param source an InputStream pointing to the statistics
   * @param types a HashMap that accumulates the counts of distinct types observed to follow each
   *          ngram
   */
  private void populateFromInputStream(InputStream source, HashMap<String, Long> types) {
    numTokens = Double.NEGATIVE_INFINITY; // = log(0)
    for (String line : new LineReader(source)) {
      String[] toks = Regex.spaces.split(line);
      if ((toks.length < 2) || (toks.length > ngramOrder + 1))
        continue;
      int[] ngram = new int[toks.length - 1];
      StringBuilder history = new StringBuilder();
      for (int i = 0; i < toks.length - 1; i++) {
        ngram[i] = Vocabulary.id(toks[i]);
        if (i < toks.length - 2)
          history.append(toks[i]).append(" ");
      }

      long cnt = Long.parseLong(toks[toks.length - 1]);
      add(ngram, cnt, countFuncs);
      if (toks.length == 2) { // unigram
        numTokens = logAdd(numTokens, Math.log(cnt));
        // no need to count types after "";
        // that's what vocabulary.size() is for.
        continue;
      }
      // record one more distinct type observed after this history
      String historyKey = history.toString();
      Long seen = types.get(historyKey);
      types.put(historyKey, (seen == null) ? 1L : seen + 1);
    }
  }

  /**
   * Adds an ngram, along with an associated value, to the Bloom filter. This corresponds to
   * Talbot and Osborne's "Tera-scale LMs on the cheap", algorithm 1.
   *
   * @param ngram an array representing the ngram
   * @param value the value to be associated with the ngram
   * @param funcs an array of long to be used as hash functions
   */
  private void add(int[] ngram, long value, long[][] funcs) {
    if (ngram == null)
      return;
    int qValue = quantize(value);
    for (int i = 1; i <= qValue; i++) {
      int hash = hashNgram(ngram, 0, ngram.length, i);
      bf.add(hash, funcs);
    }
  }

  /**
   * Reads a Bloom filter LM from an external file.
   *
   * @param in an ObjectInput stream to read from
   * @throws IOException if an input or output exception occurred
   * @throws ClassNotFoundException if a serialized class cannot be resolved
   */
  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    int vocabSize = in.readInt();
    for (int i = 0; i < vocabSize; i++) {
      String line = in.readUTF();
      Vocabulary.id(line);
    }
    numTokens = in.readDouble();
    countFuncs = new long[in.readInt()][2];
    for (int i = 0; i < countFuncs.length; i++) {
      countFuncs[i][0] = in.readLong();
      countFuncs[i][1] = in.readLong();
    }
    typesFuncs = new long[in.readInt()][2];
    for (int i = 0; i < typesFuncs.length; i++) {
      typesFuncs[i][0] = in.readLong();
      typesFuncs[i][1] = in.readLong();
    }
    quantizationBase = in.readDouble();
    bf = new BloomFilter();
    bf.readExternal(in);
  }
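  // Serialized layout, as consumed by readExternal() above and produced by
  // writeExternal() below:
  //   int                  vocabulary size V
  //   V x UTF string       vocabulary words
  //   double               numTokens (natural log of the training token count)
  //   int + long pairs     count hash functions
  //   int + long pairs     types-after hash functions
  //   double               quantizationBase
  //   BloomFilter          the filter's own external form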
  /**
   * Writes a Bloom filter LM to some external location.
   *
   * @param out an ObjectOutput stream to write to
   *
   * @throws IOException if an input or output exception occurred
   */
  public void writeExternal(ObjectOutput out) throws IOException {
    out.writeInt(Vocabulary.size());
    for (int i = 0; i < Vocabulary.size(); i++) {
      // out.writeBytes(vocabulary.getWord(i));
      // out.writeChar('\n'); // newline
      out.writeUTF(Vocabulary.word(i));
    }
    out.writeDouble(numTokens);
    out.writeInt(countFuncs.length);
    for (long[] countFunc : countFuncs) {
      out.writeLong(countFunc[0]);
      out.writeLong(countFunc[1]);
    }
    out.writeInt(typesFuncs.length);
    for (long[] typesFunc : typesFuncs) {
      out.writeLong(typesFunc[0]);
      out.writeLong(typesFunc[1]);
    }
    out.writeDouble(quantizationBase);
    bf.writeExternal(out);
  }

  /**
   * Returns the language model score for an n-gram. This is called from the rest of the Joshua
   * decoder.
   *
   * @param ngram the ngram to score
   * @param order the order of the model
   *
   * @return the language model score of the ngram
   */
  @Override
  protected float ngramLogProbability_helper(int[] ngram, int order) {
    int[] lm_ngram = new int[ngram.length];
    for (int i = 0; i < ngram.length; i++) {
      lm_ngram[i] = Vocabulary.id(Vocabulary.word(ngram[i]));
    }
    return wittenBell(lm_ngram, order);
  }

  @Override
  public boolean isOov(int id) {
    int[] ngram = new int[] {id};
    int MAX_QCOUNT = getCount(ngram, ngram.length - 1, ngram.length, maxQ);
    return (MAX_QCOUNT == 0);
  }
}