package edu.uw.easysrl.syntax.tagger; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import com.google.common.collect.HashMultimap; import com.google.common.collect.HashMultiset; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Multimap; import com.google.common.collect.Multiset; import com.google.common.collect.Multiset.Entry; import edu.uw.easysrl.main.InputReader.InputToParser; import edu.uw.easysrl.syntax.grammar.Category; import edu.uw.easysrl.syntax.grammar.SyntaxTreeNode; import edu.uw.easysrl.syntax.grammar.SyntaxTreeNode.SyntaxTreeNodeLeaf; import edu.uw.easysrl.util.Util; public class TagDict { private static final int MIN_OCCURENCES_OF_WORD = 500; /** * Key used in the tag dictionary for infrequent words */ public static final String OTHER_WORDS = "*other_words*"; private final static String fileName = "tagdict"; /** * Saves a tag dictionary to the model folder */ public static void writeTagDict(final Map<String, Collection<Category>> tagDict, final File file) throws FileNotFoundException, UnsupportedEncodingException { final PrintWriter writer = new PrintWriter(file, "UTF-8"); for (final java.util.Map.Entry<String, Collection<Category>> entry : tagDict.entrySet()) { writer.print(entry.getKey()); for (final Category c : entry.getValue()) { writer.print("\t" + c.toString()); } writer.println(); } writer.close(); } /** * Loads a tag dictionary from the model folder */ public static Map<String, Collection<Category>> readDict(final File modelFolder, final Set<Category> lexicalCategories) throws IOException { final Map<String, Collection<Category>> result = new HashMap<>(); // Hack so that annotation gets into the tag dict final Multimap<Category, Category> categoryToAnnotatedCategories = HashMultimap.create(); for (final Category category : lexicalCategories) { categoryToAnnotatedCategories.put(category.withoutAnnotation(), category); } final File file = new File(modelFolder, fileName); loadTagDict(lexicalCategories, result, file, false, categoryToAnnotatedCategories); // } if (result.size() == 0) { // No tag dictionaries available return null; } final File ccgbankTagDict = new File(modelFolder, fileName + ".ccgbank"); loadTagDict(lexicalCategories, result, ccgbankTagDict, true, categoryToAnnotatedCategories); return ImmutableMap.copyOf(result); } private static void loadTagDict(final Set<Category> lexicalCategories, final Map<String, Collection<Category>> result, final File file, final boolean skipIfNotPresent, final Multimap<Category, Category> categoryToAnnotatedCategories) throws IOException { if (!file.exists()) { return; } for (final String line : Util.readFile(file)) { final String[] fields = line.split("\t"); Collection<Category> cats = result.get(fields[0]); if (cats == null) { if (skipIfNotPresent) { continue; } else { cats = new HashSet<>(); } } for (int i = 1; i < fields.length; i++) { final Category cat = Category.valueOf(fields[i]); if (lexicalCategories.contains(cat)) { cats.addAll(categoryToAnnotatedCategories.get(cat)); } } if (cats.size() > 0) { result.put(fields[0], cats); } } } private final static Comparator<Entry<Category>> comparator = new Comparator<Entry<Category>>() { @Override public int compare(final Entry<Category> arg0, final Entry<Category> arg1) { return arg1.getCount() - arg0.getCount(); } }; /** * Finds the set of categories used for each word in a corpus */ public static Map<String, Collection<Category>> makeDict(final Iterable<InputToParser> input) { final Multiset<String> wordCounts = HashMultiset.create(); final Map<String, Multiset<Category>> wordToCatToCount = new HashMap<>(); // First, count how many times each word occurs with each category for (final InputToParser sentence : input) { for (int i = 0; i < sentence.getInputWords().size(); i++) { final String word = sentence.getInputWords().get(i).word; final Category cat = sentence.getGoldCategories().get(i); wordCounts.add(word); if (!wordToCatToCount.containsKey(word)) { final Multiset<Category> tmp = HashMultiset.create(); wordToCatToCount.put(word, tmp); } wordToCatToCount.get(word).add(cat); } } return makeDict(wordCounts, wordToCatToCount); } private static Map<String, Collection<Category>> makeDict(final Multiset<String> wordCounts, final Map<String, Multiset<Category>> wordToCatToCount) { // Now, save off a sorted list of categories final Multiset<Category> countsForOtherWords = HashMultiset.create(); final Map<String, Collection<Category>> result = new HashMap<>(); for (final Entry<String> wordAndCount : wordCounts.entrySet()) { final Multiset<Category> countForCategory = wordToCatToCount.get(wordAndCount.getElement()); if (wordAndCount.getCount() > MIN_OCCURENCES_OF_WORD) { // Frequent word addEntryForWord(countForCategory, result, wordAndCount.getElement()); } else { // Group stats for all rare words together. for (final Entry<Category> catToCount : countForCategory.entrySet()) { countsForOtherWords.add(catToCount.getElement(), catToCount.getCount()); } } } addEntryForWord(countsForOtherWords, result, OTHER_WORDS); return ImmutableMap.copyOf(result); } public static Map<String, Collection<Category>> makeDictFromParses(final Iterator<SyntaxTreeNode> input) { final Multiset<String> wordCounts = HashMultiset.create(); final Map<String, Multiset<Category>> wordToCatToCount = new HashMap<>(); int sentenceCount = 0; // First, count how many times each word occurs with each category while (input.hasNext()) { final SyntaxTreeNode sentence = input.next(); final List<SyntaxTreeNodeLeaf> leaves = sentence.getLeaves(); for (int i = 0; i < leaves.size(); i++) { final String word = leaves.get(i).getWord(); final Category cat = leaves.get(i).getCategory(); wordCounts.add(word); if (!wordToCatToCount.containsKey(word)) { final Multiset<Category> tmp = HashMultiset.create(); wordToCatToCount.put(word, tmp); } wordToCatToCount.get(word).add(cat); } sentenceCount++; if (sentenceCount % 100 == 0) { System.out.println(sentenceCount); } } return makeDict(wordCounts, wordToCatToCount); } private static void addEntryForWord(final Multiset<Category> countForCategory, final Map<String, Collection<Category>> result, final String word) { final List<Entry<Category>> cats = new ArrayList<>(); for (final Entry<Category> catToCount : countForCategory.entrySet()) { cats.add(catToCount); } final int totalSize = countForCategory.size(); final int minSize = Math.floorDiv(totalSize, 1000); Collections.sort(cats, comparator); final List<Category> cats2 = new ArrayList<>(); for (final Entry<Category> entry : cats) { if (entry.getCount() >= minSize) { cats2.add(entry.getElement()); } } result.put(word, cats2); } }