package slp.core.counting.trie; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.stream.Collectors; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntList; public class MapTrieCounter extends AbstractTrie { /** * 'counts' contains in order: own count, context count (sum of successor's counts), * no of distinct successors seen once, twice, up to the COCcutoff in Configuration */ private Int2ObjectMap<Object> map; private IntList pseudoOrdering; // Maximum depth in trie to use Map-tries, after this Array-Tries are used, which are slower but more memory-efficient private static final int MAX_DEPTH_MAP_TRIE = 1; public MapTrieCounter() { this(1); } public MapTrieCounter(int initSize) { super(); this.map = new Int2ObjectOpenHashMap<>(initSize); this.map.defaultReturnValue(null); this.pseudoOrdering = new IntArrayList(); } @Override public List<Integer> getSuccessors() { return this.map.keySet().stream().collect(Collectors.toList()); } private static Map<Integer, Integer> cache = new HashMap<>(); @Override public List<Integer> getTopSuccessorsInternal(int limit) { int classKey = this.hashCode(); int countsKey = this.keyCode(); Integer cached = cache.get(classKey); if (cached == null || cached != countsKey) { Collections.sort(this.pseudoOrdering, (i1, i2) -> compareCounts(i1, i2)); } int end = Math.min(this.pseudoOrdering.size(), limit); List<Integer> topSuccessors = new ArrayList<>(this.pseudoOrdering.subList(0, end)); if (this.getSuccessorCount() > 10) cache.put(classKey, countsKey); return topSuccessors; } @Override public int hashCode() { return super.hashCode(); } private int keyCode() { return 31*(this.getSuccessorCount() + 31*this.getCount()); } @Override AbstractTrie makeNext(int depth) { AbstractTrie newNext; if (depth <= MAX_DEPTH_MAP_TRIE) newNext = new MapTrieCounter(1); else newNext = new ArrayTrieCounter(); return newNext; } @Override public Object getSuccessor(int next) { return this.map.get(next); } @Override void putSuccessor(int next, Object o) { Object curr = this.map.put(next, o); if (curr == null) this.pseudoOrdering.add(next); } private int compareCounts(Integer i1, Integer i2) { int base = -Integer.compare(getCount(this.map.get((int) i1)), getCount(this.map.get((int) i2))); if (base != 0) return base; return Integer.compare(i1, i2); } @Override void removeSuccessor(int next) { Object removed = this.map.remove(next); this.pseudoOrdering.rem(next); if (removed instanceof MapTrieCounter) { cache.remove(((MapTrieCounter) removed).hashCode()); } } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { this.counts = new int[2 + COUNT_OF_COUNTS_CUTOFF]; this.counts[0] = in.readInt(); this.counts[1] = in.readInt(); int successors = in.readInt(); this.map = new Int2ObjectOpenHashMap<>(successors, 0.9f); int pos = 0; for (; pos < successors; pos++) { int key = in.readInt(); int code = in.readInt(); Object value; if (code < 0) { if (code < -1) value = new ArrayTrieCounter(); else value = new MapTrieCounter(); ((AbstractTrie) value).readExternal(in); this.counts[1 + Math.min(((AbstractTrie) value).getCount(), COUNT_OF_COUNTS_CUTOFF)]++; } else { value = new int[code]; for (int j = 0; j < code; j++) ((int[]) value)[j] = in.readInt(); this.counts[1 + Math.min(((int[]) value)[0], COUNT_OF_COUNTS_CUTOFF)]++; } this.putSuccessor(key, value); } } @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(this.counts[0]); out.writeInt(this.counts[1]); out.writeInt(this.map.size()); for (Entry<Integer, Object> entry : this.map.int2ObjectEntrySet()) { int key = entry.getKey(); Object value = entry.getValue(); out.writeInt(key); Object o = value; if (o instanceof int[]) { int[] arr = (int[]) o; out.writeInt(arr.length); for (int j = 0; j < arr.length; j++) out.writeInt(arr[j]); } else { if (o instanceof ArrayTrieCounter) out.writeInt(-2); else out.writeInt(-1); ((AbstractTrie) o).writeExternal(out); } } } }