package edu.uw.easysrl.syntax.parser; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.HashBasedTable; import com.google.common.collect.HashMultimap; import com.google.common.collect.ListMultimap; import com.google.common.collect.MinMaxPriorityQueue; import com.google.common.collect.Multimap; import com.google.common.collect.Table; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; import edu.uw.easysrl.dependencies.UnlabelledDependency; import edu.uw.easysrl.syntax.grammar.Category; import edu.uw.easysrl.syntax.grammar.SyntaxTreeNode; import edu.uw.easysrl.syntax.model.AgendaItem; import edu.uw.easysrl.util.FastTreeMap; public abstract class ChartCell { /** * Possibly adds a @AgendaItem to this chart cell. Returns true if the parse was added, and false if the cell was * unchanged. */ public final boolean add(final AgendaItem entry) { return add(entry.getEquivalenceClassKey(), entry); } public static abstract class ChartCellFactory { public abstract ChartCell make(); /** * Get factory for a new sentence. */ public ChartCellFactory forNewSentence() { return this; } } public abstract boolean add(final Object key, final AgendaItem entry); public abstract Iterable<AgendaItem> getEntries(); public abstract int size(); /** * Chart Cell used for 1-best parsing. */ protected static class Cell1Best extends ChartCell { final Map<Object, AgendaItem> keyToProbability = new HashMap<>(); @Override public Collection<AgendaItem> getEntries() { return keyToProbability.values(); } @Override public boolean add(final Object key, final AgendaItem entry) { return keyToProbability.putIfAbsent(key, entry) == null; } @Override public int size() { return keyToProbability.size(); } public static ChartCellFactory factory() { return new ChartCellFactory() { @Override public ChartCell make() { return new Cell1Best(); } }; } } /** * ChartCell for A* parsing that uses a custom tree data structure, rather than a hash map. It'll die horribly if * the keys aren't comparable. */ protected static class Cell1BestTreeBased extends ChartCell { final FastTreeMap<Object, AgendaItem> keyToProbability = new FastTreeMap<>(); @Override public Iterable<AgendaItem> getEntries() { return keyToProbability.values(); } @Override public boolean add(final Object key, final AgendaItem entry) { return keyToProbability.putIfAbsent(key, entry); } @Override public int size() { return keyToProbability.size(); } public static ChartCellFactory factory() { return new ChartCellFactory() { @Override public ChartCell make() { return new Cell1BestTreeBased(); } }; } } /** * ChartCell for CKY parsing. The main difference with A* is that it needs to check if new entries have a higher * score than existing entries (which can't happen with A*). * */ protected static class Cell1BestCKY extends Cell1Best { @Override public Collection<AgendaItem> getEntries() { return keyToProbability.values(); } @Override public boolean add(final Object key, final AgendaItem entry) { final AgendaItem currentEntry = keyToProbability.get(key); if (currentEntry == null || entry.getInsideScore() > currentEntry.getInsideScore()) { keyToProbability.put(key, entry); return true; } else { return false; } } } /** * Allows a limited or unbounded number of items in a cell, without dividing them into equivalence classes. * * Could also be used in conjunction with dependency hashing? */ static class CellNoDynamicProgram extends ChartCell { private final Collection<AgendaItem> entries; CellNoDynamicProgram() { this.entries = new ArrayList<>(); } CellNoDynamicProgram(int nbest) { this.entries = MinMaxPriorityQueue.maximumSize(nbest).create(); } @Override public Collection<AgendaItem> getEntries() { return entries; } @Override public boolean add(final Object key, final AgendaItem newEntry) { return entries.add(newEntry); } @Override public int size() { return entries.size(); } public static ChartCellFactory factory() { return new ChartCellFactory() { @Override public ChartCell make() { return new CellNoDynamicProgram(); } }; } public static ChartCellFactory factory(final int nbest) { return new ChartCellFactory() { @Override public ChartCell make() { return new CellNoDynamicProgram(nbest); } }; } } /** * Implements dependency hashing for better N-best parsing, as in Ng&Curran 2012 */ public static class ChartCellNbestFactory extends ChartCellFactory { private final int nbest; private final double nbestBeam; public ChartCellNbestFactory(final int nbest, final double nbestBeam, final int maxSentenceLength, final Collection<Category> categories) { super(); this.nbest = nbest; this.nbestBeam = nbestBeam; final Random randomGenerator = new Random(); // Build a hash for every possible dependency categoryToArgumentToHeadToModifierToHash = HashBasedTable.create(); for (final Category c : categories) { for (int i = 1; i <= c.getNumberOfArguments(); i++) { final int[][] array = new int[maxSentenceLength][maxSentenceLength]; categoryToArgumentToHeadToModifierToHash.put(c, i, array); for (int head = 0; head < maxSentenceLength; head++) { for (int child = 0; child < maxSentenceLength; child++) { array[head][child] = randomGenerator.nextInt(); } } } } } public ChartCellNbestFactory(ChartCellNbestFactory other) { this.nbest = other.nbest; this.nbestBeam = other.nbestBeam; this.categoryToArgumentToHeadToModifierToHash = other.categoryToArgumentToHeadToModifierToHash; } // A cache of hash scores for nodes, to save recomputing them. I'm not in love with this design, but at least it // keeps all the hashing code in one place. private final Map<SyntaxTreeNode, Integer> nodeToHash = new HashMap<>(); private final Table<Category, Integer, int[][]> categoryToArgumentToHeadToModifierToHash; private int getHash(final SyntaxTreeNode parse) { Integer result = nodeToHash.get(parse); if (result == null) { result = 0; // Add in a hash for each dependency at this node. final List<UnlabelledDependency> resolvedUnlabelledDependencies = parse .getResolvedUnlabelledDependencies(); if (resolvedUnlabelledDependencies != null) { for (final UnlabelledDependency dep : resolvedUnlabelledDependencies) { for (final int arg : dep.getArguments()) { if (dep.getHead() != arg) { result = result ^ categoryToArgumentToHeadToModifierToHash.get(dep.getCategory(), dep.getArgNumber())[dep.getHead()][arg]; } } } } for (final SyntaxTreeNode child : parse.getChildren()) { result = result ^ getHash(child); } } return result; } /** * Chart Cell used for N-best parsing. It allows multiple entries with the same key, but doesn't check for * equivalence */ protected class CellNBest extends ChartCell { private final ListMultimap<Object, AgendaItem> keyToEntries = ArrayListMultimap.create(); @Override public Collection<AgendaItem> getEntries() { return keyToEntries.values(); } @Override public boolean add(final Object key, final AgendaItem newEntry) { final List<AgendaItem> existing = keyToEntries.get(key); if (existing.size() > nbest || (existing.size() > 0 && newEntry.getCost() < existing.get(0).getCost() + Math.log(nbestBeam))) { return false; } else { // Only cache out hashes for nodes that get added to the chart. keyToEntries.put(key, newEntry); return true; } } @Override public int size() { return keyToEntries.size(); } } /** * Chart Cell used for N-best parsing. It allows multiple entries with the same key, if they are not equivalent. */ class CellNBestWithHashing extends ChartCell { private final ListMultimap<Object, AgendaItem> keyToEntries = ArrayListMultimap.create(); private final Multimap<Object, Integer> keyToHash = HashMultimap.create(); @Override public Collection<AgendaItem> getEntries() { return keyToEntries.values(); } @Override public boolean add(final Object key, final AgendaItem newEntry) { final List<AgendaItem> existing = keyToEntries.get(key); if (existing.size() > nbest || (existing.size() > 0 && newEntry.getCost() < existing.get(0).getCost() + Math.log(nbestBeam))) { return false; } else { final Integer hash = getHash(newEntry.getParse()); if (keyToHash.containsEntry(key, hash)) { // Already have an equivalent node. return false; } keyToEntries.put(key, newEntry); keyToHash.put(key, hash); // Cache out hash for this parse. nodeToHash.put(newEntry.getParse(), hash); return true; } } @Override public int size() { return keyToEntries.size(); } } @Override public ChartCell make() { return // new CellNBest(); new CellNBestWithHashing(); } @Override public ChartCellFactory forNewSentence() { return new ChartCellNbestFactory(this); } } }