package com.indeed.mph.serializers; import com.indeed.mph.LinearDiophantineEquation; import com.indeed.mph.TableWriter; import com.indeed.util.io.Files; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; import it.unimi.dsi.sux4j.mph.GOV4Function; import java.io.DataInput; import java.io.DataOutput; import java.io.File; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.Arrays; import java.util.Map; /** * General serializer for any "dictionary" of terms, storing the terms * as unique ids. This is useful for enums or fixed sets of terms * which are repeated, but if most inputs are unique it's more * efficient to just serialize as a string. The same serializer (or * equivalent after serialization) must be used to deserialize, or the * ids will not match. * * @author alexs */ public class SmartDictionarySerializer extends AbstractSmartSerializer<String> { private static final long serialVersionUID = 2138609301; private final SmartVLongSerializer serializer = new SmartVLongSerializer(); private Object2IntMap<String> dictionary = new Object2IntOpenHashMap<>(); private GOV4Function<String> mphFunction; private String[] words; private boolean onlyUsedInValue; public SmartDictionarySerializer() { this(false); } public SmartDictionarySerializer(final boolean onlyUsedInValue) { this.onlyUsedInValue = onlyUsedInValue; this.mphFunction = null; } @Override public String parseFromString(final String str) throws IOException { return str; } @Override public void write(final String str, final DataOutput out) throws IOException { serializer.write(getIndex(str), out); } @Override public String read(final DataInput in) throws IOException { final Long n = serializer.read(in); if (words == null) { synchronized (this) { if (words == null) { words = dictionaryToIndex(dictionary); } } } if (n < 0 || n >= words.length) { throw new IOException("read unknown serialized id: " + n); } return words[n.intValue()]; } @Override public LinearDiophantineEquation size() { return serializer.size(); } private long getIndex(final String str) throws IOException { if (mphFunction != null) { final long index = mphFunction.getLong(str); // Validate the index and string matched. if (index < 0 || index >= words.length || !str.equals(words[(int)(index)])) { return -1; } return index; } if (dictionary == null) { if (words == null) { throw new IOException("invalid dictionary, flat and mapped indexes both null"); } synchronized (this) { if (dictionary == null) { dictionary = indexToDictionary(words); } } } final Integer n = dictionary.get(str); if (n == null) { synchronized (this) { final Integer n2 = dictionary.get(str); if (n2 != null) { return n2; } final int result = dictionary.size(); dictionary.put(str, result); words = null; // invalidate current index return result; } } return n; } private Object2IntMap<String> indexToDictionary(final String[] words) throws IOException { final Object2IntMap<String> result = new Object2IntOpenHashMap<>(); for (int i = 0; i < words.length; ++i) { result.put(words[i], i); } return result; } private String[] dictionaryToIndex(final Object2IntMap<String> dict) throws IOException { final String[] result = new String[dict.size()]; for (final Map.Entry<String, Integer> entry : dict.entrySet()) { final String word = entry.getKey(); final Integer index = entry.getValue(); if (index == null || index < 0 || index >= result.length) { throw new IOException("inconsistent dictionary, has " + result.length + " entries but an index of " + word + " -> " + index); } if (result[index] != null) { throw new IOException("inconsistent dictionary, both " + result[index] + " and " + word + " map to " + index); } result[index] = word; } return result; } private GOV4Function<String> buildMphFunction() throws IOException { final File tempFolder = File.createTempFile("smartDictionarySerializer", ".tmp"); if (!Files.delete(tempFolder.getAbsolutePath())) { throw new IOException("Can't delete tempFolder: " + tempFolder); } if (!tempFolder.mkdir()) { throw new IOException("Can't create tempFolder: " + tempFolder); } return new GOV4Function.Builder<String>() .keys(Arrays.asList(words)) .tempDir(tempFolder) .transform(new TableWriter.SerializerTransformationStrategy<>(new SmartStringSerializer())) .build(); } // use default serialization, but compact to just the flat index first private void writeObject(final ObjectOutputStream outputStream) throws IOException { if (words == null) { words = dictionaryToIndex(dictionary); } dictionary = null; if (onlyUsedInValue) { mphFunction = null; } else { mphFunction = buildMphFunction(); } outputStream.defaultWriteObject(); // To support this serializer used in many configs. mphFunction = null; } private void readObject(final ObjectInputStream inputStream) throws IOException, ClassNotFoundException { inputStream.defaultReadObject(); if (words == null) { throw new IOException("words can't be null"); } if (!onlyUsedInValue && mphFunction == null) { mphFunction = buildMphFunction(); } } }