package lucene4ir.utils;

import lucene4ir.Lucene4IRConstants;

import org.apache.lucene.index.*;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * Builds a unigram language model over one or more documents (optionally
 * weighted) from their stored term vectors, and provides maximum-likelihood,
 * Jelinek-Mercer and Dirichlet smoothed term probability estimates.
 *
 * Created by leif on 14/09/2017.
 */
public class LanguageModel {

    protected IndexReader reader;
    private IndexSearcher searcher;
    private CollectionStatistics collectionStats;
    public String field = Lucene4IRConstants.FIELD_ALL;
    private int[] doc_ids;
    private double doc_len;
    private HashMap<String, Double> termCounts = new HashMap<>();
    private long token_count;

    public LanguageModel(IndexReader ir, int doc_id) {
        reader = ir;
        searcher = new IndexSearcher(reader);
        doc_ids = new int[]{doc_id};
        doc_len = 0.0;
        updateTermCountMap(doc_id, 1.0);
        initCollectionStats();
    }

    public LanguageModel(IndexReader ir, int[] doc_ids) {
        reader = ir;
        searcher = new IndexSearcher(reader);
        this.doc_ids = doc_ids;
        doc_len = 0.0;
        for (int doc_id : doc_ids) {
            updateTermCountMap(doc_id, 1.0);
        }
        initCollectionStats();
    }

    public LanguageModel(IndexReader ir, int[] doc_ids, double[] weights) {
        assert doc_ids.length == weights.length;
        reader = ir;
        searcher = new IndexSearcher(reader);
        this.doc_ids = doc_ids;
        doc_len = 0.0;
        for (int i = 0; i < doc_ids.length; i++) {
            updateTermCountMap(doc_ids[i], weights[i]);
        }
        initCollectionStats();
    }

    /**
     * Fetches collection-wide statistics for the field; token_count is the
     * total number of tokens in the collection, used to compute P(t|C).
     */
    private void initCollectionStats() {
        try {
            collectionStats = searcher.collectionStatistics(field);
            token_count = collectionStats.sumTotalTermFreq();
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Maximum-likelihood estimate P(t|d) = tf(t,d) / |d|.
     */
    private double getDocumentTermProb(String termText) {
        if (termCounts.containsKey(termText)) {
            return termCounts.get(termText) / doc_len;
        } else {
            System.out.println("Term does not occur in document.");
            return 0.0;
        }
    }

    /**
     * Returns the (weighted) term frequency tf(t,d), or 0.0 if absent.
     */
    private double getDocumentTermCount(String termText) {
        if (termCounts.containsKey(termText)) {
            return termCounts.get(termText);
        } else {
            System.out.println("Term does not occur in document.");
            return 0.0;
        }
    }

    /**
     * Background model P(t|C) = ctf(t) / (|C| + 1); the +1 guards against
     * division by zero on an empty collection.
     */
    @SuppressWarnings("WeakerAccess")
    public double getCollectionTermProb(String termText) {
        double prob = 0.0;
        try {
            Term termInstance = new Term(field, termText);
            long termFreq = reader.totalTermFreq(termInstance);
            prob = termFreq / (token_count + 1.0);
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(1);
        }
        return prob;
    }

    /**
     * Accumulates the document's term frequencies (scaled by weight) into
     * termCounts, and adds its weighted length to doc_len. Requires term
     * vectors to have been stored for the field at indexing time.
     */
    private void updateTermCountMap(int doc_id, double weight) {
        try {
            Terms t = reader.getTermVector(doc_id, field);
            if ((t != null) && (t.size() > 0)) {
                TermsEnum te = t.iterator();
                BytesRef term;
                while ((term = te.next()) != null) {
                    String termText = term.utf8ToString();
                    double tf = te.totalTermFreq() * weight;
                    termCounts.merge(termText, tf, Double::sum);
                    doc_len = doc_len + tf;
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(1);
        }
    }
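
    /*
     * A quick sketch of the two smoothing schemes implemented below, where
     * P(t|d) is the maximum-likelihood document estimate and P(t|C) the
     * collection (background) estimate:
     *
     *   Jelinek-Mercer:  P_jm(t|d)  = lambda * P(t|d) + (1 - lambda) * P(t|C)
     *   Dirichlet:       P_dir(t|d) = (tf(t,d) + mu * P(t|C)) / (|d| + mu)
     *
     * lambda in [0, 1] and mu > 0 control how strongly the document estimate
     * is smoothed towards the background model.
     */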
    @SuppressWarnings("WeakerAccess")
    public double getJMTermProb(String termText, double lambda) {
        return (lambda * getDocumentTermProb(termText))
                + (1 - lambda) * getCollectionTermProb(termText);
    }

    @SuppressWarnings("WeakerAccess")
    public double getDirichletTermProb(String termText, double mu) {
        return (getDocumentTermCount(termText) + mu * getCollectionTermProb(termText))
                / (doc_len + mu);
    }

    /**
     * KL divergence between the JM-smoothed document model and the collection
     * model, restricted to the terms that occur in the document (terms absent
     * from the document are not included in the sum).
     */
    public double KLDivergence(double lambda) {
        double klDiv = 0.0;
        for (String term : termCounts.keySet()) {
            double px = getJMTermProb(term, lambda);
            double qx = getCollectionTermProb(term);
            klDiv += px * Math.log(px / qx);
        }
        return klDiv;
    }

    /**
     * Prints one line per term: raw weighted count, ML probability, collection
     * probability, JM-smoothed (lambda = 0.5) and Dirichlet-smoothed (mu = 100)
     * estimates, followed by the totals.
     */
    public void printTermVector() {
        double tProb = 0.0;
        double tCount = 0.0;
        for (Map.Entry<String, Double> m : termCounts.entrySet()) {
            String termText = m.getKey();
            double count = getDocumentTermCount(termText);
            tCount = tCount + count;
            double prob = getDocumentTermProb(termText);
            double cProb = getCollectionTermProb(termText);
            double jmProb = getJMTermProb(termText, 0.5);
            double dirProb = getDirichletTermProb(termText, 100.0);
            System.out.println(termText + " " + m.getValue() + " " + prob + " "
                    + cProb + " " + jmProb + " " + dirProb);
            tProb = tProb + prob;
        }
        System.out.println("Total prob mass: " + tProb + " total term count: "
                + tCount + " doc size: " + doc_len);
    }
}
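
/*
 * Usage sketch (not part of the class above): assumes an index whose documents
 * were indexed with term vectors stored on Lucene4IRConstants.FIELD_ALL. The
 * index path "path/to/index", document id 0, and the term "retrieval" are
 * illustrative values only.
 *
 *   IndexReader reader = DirectoryReader.open(
 *           FSDirectory.open(Paths.get("path/to/index")));
 *   LanguageModel lm = new LanguageModel(reader, 0);
 *   System.out.println(lm.getJMTermProb("retrieval", 0.5));
 *   System.out.println(lm.getDirichletTermProb("retrieval", 100.0));
 *   lm.printTermVector();
 *   reader.close();
 */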