package fasttext.mmap;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import fasttext.*;
import fasttext.store.MMapFile;
import fasttext.store.ResourceInput;
import fasttext.store.util.ByteUtils;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Memory-mapped dictionary implementation of {@link BaseDictionary}.
 *
 * <p>The lookup tables (word hashes, ids and pruning index) are loaded into arrays by
 * {@link #load(Args, MMapFile)}; the entries themselves are read on demand from the
 * {@link ResourceInput}. Every entry occupies the same number of bytes
 * ({@link #entryByteArrayLength()}), laid out as implied by the offset helpers below:
 *
 * <pre>
 *   int    word length
 *   byte[] word            (wordByteArrayLength bytes)
 *   byte   entry type
 *   long   count
 *   int    subwords size
 *   byte[] subwords        (subwordsByteArrayLength bytes)
 * </pre>
 *
 * <p>Each read first repositions the input held by this instance. {@link #clone()} gives the
 * copy its own cloned input.
 */
public class MMapDictionary extends BaseDictionary {

  /**
   * Size in bytes of the fixed part of the header read by {@link #load(Args, MMapFile)}:
   * five ints (word/subwords byte-array lengths, size, nWords, nLabels) and two longs
   * (nTokens, pruneIdxSize).
   */
  private static final long HEADER_BYTES = 5L * Integer.BYTES + 2L * Long.BYTES;

  private final MMapFile mmapFile;

  private final long entriesPositionOffset;
  private final int wordByteArrayLength;
  private final int subwordsByteArrayLength;

  protected final long[] wordHashes;
  protected final int[] ids;

  protected final int[] pruneKeys;
  protected final int[] pruneValues;

  private ResourceInput in;

  private MMapDictionary(Args args,
                         int size,
                         int nWords,
                         int nLabels,
                         long nTokens,
                         int pruneIdxSize,
                         MMapFile mmapFile,
                         ResourceInput in,
                         long entriesPositionOffset,
                         int wordByteArrayLength,
                         int subwordsByteArrayLength,
                         long[] wordHashes,
                         int[] ids,
                         int[] pruneKeys,
                         int[] pruneValues) {
    super(args, size, nWords, nLabels, nTokens, pruneIdxSize);
    this.mmapFile = mmapFile;
    this.in = in;
    this.entriesPositionOffset = entriesPositionOffset;
    this.wordByteArrayLength = wordByteArrayLength;
    this.subwordsByteArrayLength = subwordsByteArrayLength;
    this.wordHashes = wordHashes;
    this.ids = ids;
    this.pruneKeys = pruneKeys;
    this.pruneValues = pruneValues;
    // ngrams are already initialized
    initTableDiscard();
  }

  /**
   * Number of bytes occupied by one entry: word length prefix + word bytes + type byte +
   * count + subwords size prefix + subwords bytes.
   */
  private int entryByteArrayLength() {
    return wordByteArrayLength
      + subwordsByteArrayLength
      + Integer.BYTES
      + Byte.BYTES
      + Long.BYTES
      + Integer.BYTES;
  }

  /**
   * Position of an entry based on its id. Returns the position.
   * The multiplication is done in {@code long} arithmetic.
   */
  private long entryPosition(int id) {
    return entriesPositionOffset + (long) entryByteArrayLength() * id;
  }

  /**
   * Position of an entry field based on the entry's id and the field's offset.
   */
  private long entryFieldPosition(int id, int offset) {
    return entryPosition(id) + (long) offset;
  }

  /**
   * Offset to access to the word in the entry byte array.
   * Word is the first element of the byte array.
   */
  private int wordOffset() {
    return 0;
  }

  /**
   * Offset to access to the entry type in the entry byte array.
   * Consists in: Integer.BYTES word length + word length
   */
  private int typeOffset() {
    return Integer.BYTES + wordByteArrayLength;
  }

  /**
   * Offset to access to the count value in the entry byte array
   * Consists in: Integer.BYTES word length + word length + Byte.BYTES type encoding
   */
  private int countOffset() {
    // The count follows the one-byte type, so only Byte.BYTES is added after the word.
    return Integer.BYTES + wordByteArrayLength + Byte.BYTES;
  }

  /**
   * Offset to access to the subwords array in the entry byte array
   * Consists in: Integer.BYTES word length + word length + Byte.BYTES type encoding + Long.BYTES count
   */
  private int subwordsOffset() {
    return Integer.BYTES + wordByteArrayLength + Byte.BYTES + Long.BYTES;
  }

  /**
   * Moves the input to the given absolute position.
   *
   * @throws IllegalArgumentException if the underlying seek fails (the I/O error is the cause)
   */
  private void position(long pos) {
    try {
      in.seek(pos);
    } catch (IOException ex) {
      throw new IllegalArgumentException("Could not seek position " + pos, ex);
    }
  }

  /**
   * Reads the entry type at the current input position.
   *
   * @throws IllegalArgumentException if the read fails (the I/O error is the cause)
   */
  private EntryType readType() {
    try {
      return EntryType.fromValue(in.readByteAsInt());
    } catch (IOException ex) {
      throw new IllegalArgumentException("Could not read bytes to EntryType", ex);
    }
  }

  /**
   * Reads a word at the current input position: an int length prefix followed by the
   * fixed-size word slot, of which only the first {@code length} bytes are decoded as UTF-8.
   * The whole slot is consumed so the input ends up on the next field.
   *
   * @throws IllegalArgumentException if the read fails (the I/O error is the cause)
   */
  private String readWord() {
    try {
      int currWordLength = in.readInt();
      byte[] barr = new byte[wordByteArrayLength];
      in.readBytes(barr, 0, wordByteArrayLength);
      return new String(barr, 0, currWordLength, StandardCharsets.UTF_8);
    } catch (IOException ex) {
      throw new IllegalArgumentException("Could not read bytes to String", ex);
    }
  }

  /**
   * Reads the count at the current input position.
   *
   * @throws IllegalArgumentException if the read fails (the I/O error is the cause)
   */
  private long readCount() {
    try {
      return in.readLong();
    } catch (IOException ex) {
      throw new IllegalArgumentException("Could not read bytes to long", ex);
    }
  }

  /**
   * Reads the subwords at the current input position: an int size prefix followed by the
   * fixed-size subwords slot, which is handed to {@link ByteUtils#getIntArray} together with
   * the size prefix (presumably the number of ints to decode; verify against ByteUtils).
   *
   * @throws IllegalArgumentException if the read fails (the I/O error is the cause)
   */
  private int[] readSubwords() {
    try {
      int currSubwordsSize = in.readInt();
      byte[] barr = new byte[subwordsByteArrayLength];
      in.readBytes(barr, 0, subwordsByteArrayLength);
      return ByteUtils.getIntArray(barr, 0, currSubwordsSize);
    } catch (IOException ex) {
      throw new IllegalArgumentException("Could not read bytes to array of ints", ex);
    }
  }

  /**
   * Reads a whole entry sequentially from the current input position.
   * Fields are read in storage order: word, type, count, subwords. This must match
   * {@link #typeOffset()}, {@link #countOffset()} and {@link #subwordsOffset()}.
   */
  private Entry readEntry() {
    Entry e = new Entry();
    e.setWord(readWord());
    e.setType(readType());
    e.setCount(readCount());
    e.setSubwords(Ints.asList(readSubwords()));
    return e;
  }

  /**
   * Looks up the id stored for a hash by binary search over {@code wordHashes}.
   *
   * @return the id at the matching index, or {@code WORD_ID_DEFAULT} when the hash is absent
   */
  @Override
  protected int hashToId(long h) {
    int idx = Arrays.binarySearch(wordHashes, h);
    if (idx >= 0) {
      return ids[idx];
    }
    return WORD_ID_DEFAULT;
  }

  /**
   * Looks up the pruning value stored for an id by binary search over {@code pruneKeys}.
   *
   * @return the value at the matching index, or {@code -1} when the id is absent
   */
  @Override
  protected int getPruning(int id) {
    int idx = Arrays.binarySearch(pruneKeys, id);
    if (idx >= 0) {
      return pruneValues[idx];
    }
    return -1;
  }

  /**
   * Reads the full entry with the given id.
   *
   * @throws IndexOutOfBoundsException if {@code id} is not in {@code [0, size)}
   */
  @Override
  public Entry getEntry(int id) {
    // checkElementIndex rejects id == size, which would point past the last entry.
    Preconditions.checkElementIndex(id, size);
    position(entryPosition(id));
    return readEntry();
  }

  /**
   * Reads only the type of the entry with the given id.
   *
   * @throws IndexOutOfBoundsException if {@code id} is not in {@code [0, size)}
   */
  @Override
  public EntryType getType(int id) {
    Preconditions.checkElementIndex(id, size);
    position(entryFieldPosition(id, typeOffset()));
    return readType();
  }

  /**
   * Reads the word of the entry with the given id.
   *
   * @throws IndexOutOfBoundsException if {@code id} is not in {@code [0, nWords)}
   */
  @Override
  public String getWord(int id) {
    Preconditions.checkElementIndex(id, nWords);
    position(entryPosition(id));
    return readWord();
  }

  /**
   * Reads the label with the given label id; labels are addressed as entry
   * {@code lid + nWords}.
   *
   * @throws IndexOutOfBoundsException if {@code lid} is not in {@code [0, nLabels)}
   */
  @Override
  public String getLabel(int lid) {
    Preconditions.checkElementIndex(lid, nLabels);
    position(entryPosition(lid + nWords));
    return readWord();
  }

  /**
   * Reads only the count of the entry with the given id.
   *
   * @throws IndexOutOfBoundsException if {@code id} is not in {@code [0, size)}
   */
  @Override
  public long getCount(int id) {
    Preconditions.checkElementIndex(id, size);
    position(entryFieldPosition(id, countOffset()));
    return readCount();
  }

  /**
   * Reads only the subwords of the entry with the given id.
   *
   * @throws IndexOutOfBoundsException if {@code id} is not in {@code [0, size)}
   */
  @Override
  public List<Integer> getSubwords(int id) {
    Preconditions.checkElementIndex(id, size);
    position(entryFieldPosition(id, subwordsOffset()));
    return Ints.asList(readSubwords());
  }

  /**
   * Reads every entry, in id order, into a new array of length {@code size}.
   */
  @Override
  public Entry[] getEntries() {
    Entry[] words = new Entry[size];
    for (int id = 0; id < size; id++) {
      words[id] = getEntry(id);
    }
    return words;
  }

  /**
   * Opens an input on the given memory-mapped file, reads the dictionary header and lookup
   * tables, and returns a dictionary that reads its entries lazily from that input.
   *
   * <p>Data is read in this order: word and subwords byte-array lengths (ints); size, nWords,
   * nLabels (ints); nTokens and pruneIdxSize (longs); prune keys then prune values (ints, only
   * when pruneIdxSize is positive); word hashes (longs) then ids (ints). Entries start right
   * after these tables.
   *
   * <p>If anything fails after the input has been opened, the input is closed before the
   * exception propagates.
   *
   * @param args model arguments handed to the base dictionary
   * @param mmap memory-mapped file to read the dictionary from
   * @return the loaded dictionary
   * @throws IOException if reading the header or tables fails
   */
  public static MMapDictionary load(Args args, MMapFile mmap) throws IOException {
    ResourceInput in = mmap.openInput();
    try {
      // dictionary mmap utilities
      int wordByteArrayLength = in.readInt();
      int subwordsByteArrayLength = in.readInt();
      // dictionary meta data
      int size = in.readInt();
      int nWords = in.readInt();
      int nLabels = in.readInt();
      long nTokens = in.readLong();
      int pruneIdxSize = (int) in.readLong();
      int pruneArrSize = Math.max(0, pruneIdxSize);
      int[] pruneKeys = new int[pruneArrSize];
      int[] pruneValues = new int[pruneArrSize];
      if (pruneIdxSize > 0) {
        for (int i = 0; i < pruneIdxSize; i++) {
          pruneKeys[i] = in.readInt();
        }
        for (int i = 0; i < pruneIdxSize; i++) {
          pruneValues[i] = in.readInt();
        }
      }
      // word2int
      long[] wordHashes = new long[size];
      int[] ids = new int[size];
      for (int i = 0; i < size; i++) {
        wordHashes[i] = in.readLong();
      }
      for (int i = 0; i < size; i++) {
        ids[i] = in.readInt();
      }
      // Header + prune table (key + value ints) + word2int table (hash long + id int).
      // Computed in long arithmetic so large dictionaries cannot overflow an int.
      long entriesPositionOffset = HEADER_BYTES
        + 2L * Integer.BYTES * pruneArrSize
        + (long) (Long.BYTES + Integer.BYTES) * size;
      return new MMapDictionary(args, size, nWords, nLabels, nTokens, pruneIdxSize,
        mmap, in, entriesPositionOffset,
        wordByteArrayLength, subwordsByteArrayLength,
        wordHashes, ids, pruneKeys, pruneValues);
    } catch (IOException | RuntimeException ex) {
      // Do not leak the input opened above when loading fails.
      try {
        in.close();
      } catch (Exception closeEx) {
        ex.addSuppressed(closeEx);
      }
      throw ex;
    }
  }

  /**
   * Returns a copy of this dictionary that has its own clone of the input; all other
   * fields are shared with the original as produced by {@code super.clone()}.
   */
  @Override
  public MMapDictionary clone() throws CloneNotSupportedException {
    MMapDictionary d = (MMapDictionary) super.clone();
    d.in = in.clone();
    return d;
  }

  /**
   * Closes the input held by this dictionary.
   *
   * @throws IOException if closing the input fails
   */
  public void close() throws IOException {
    in.close();
  }

  /**
   * Not supported.
   *
   * @throws UnsupportedOperationException always
   */
  public void saveToMMap(OutputStream os) throws IOException {
    throw new UnsupportedOperationException("Not implemented yet");
  }

}