/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.kylin.dict; import java.io.ByteArrayInputStream; import java.io.DataInput; import java.io.DataInputStream; import java.io.DataOutput; import java.io.IOException; import java.io.PrintStream; import java.lang.ref.SoftReference; import java.util.Arrays; import java.util.HashMap; import org.apache.kylin.common.util.BytesUtil; import org.apache.kylin.common.util.ClassUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A dictionary based on Trie data structure that maps enumerations of byte[] to * int IDs. * * With Trie the memory footprint of the mapping is kinda minimized at the cost * CPU, if compared to HashMap of ID Arrays. Performance test shows Trie is * roughly 10 times slower, so there's a cache layer overlays on top of Trie and * gracefully fall back to Trie using a weak reference. * * The implementation is thread-safe. * * @author yangli9 */ @SuppressWarnings({ "rawtypes", "unchecked" }) public class TrieDictionary<T> extends Dictionary<T> { public static final byte[] HEAD_MAGIC = new byte[] { 0x54, 0x72, 0x69, 0x65, 0x44, 0x69, 0x63, 0x74 }; // "TrieDict" public static final int HEAD_SIZE_I = HEAD_MAGIC.length; public static final int BIT_IS_LAST_CHILD = 0x80; public static final int BIT_IS_END_OF_VALUE = 0x40; private static final Logger logger = LoggerFactory.getLogger(TrieDictionary.class); private byte[] trieBytes; // non-persistent part transient private int headSize; @SuppressWarnings("unused") transient private int bodyLen; transient private int sizeChildOffset; transient private int sizeNoValuesBeneath; transient private int baseId; transient private int maxValueLength; transient private BytesConverter<T> bytesConvert; transient private int nValues; transient private int sizeOfId; transient private int childOffsetMask; transient private int firstByteOffset; transient private boolean enableCache = true; transient private SoftReference<HashMap> valueToIdCache; transient private SoftReference<Object[]> idToValueCache; public TrieDictionary() { // default constructor for Writable interface } public TrieDictionary(byte[] trieBytes) { init(trieBytes); } private void init(byte[] trieBytes) { this.trieBytes = trieBytes; if (BytesUtil.compareBytes(HEAD_MAGIC, 0, trieBytes, 0, HEAD_MAGIC.length) != 0) throw new IllegalArgumentException("Wrong file type (magic does not match)"); try { DataInputStream headIn = new DataInputStream( // new ByteArrayInputStream(trieBytes, HEAD_SIZE_I, trieBytes.length - HEAD_SIZE_I)); this.headSize = headIn.readShort(); this.bodyLen = headIn.readInt(); this.sizeChildOffset = headIn.read(); this.sizeNoValuesBeneath = headIn.read(); this.baseId = headIn.readShort(); this.maxValueLength = headIn.readShort(); String converterName = headIn.readUTF(); if (converterName.isEmpty() == false) this.bytesConvert = (BytesConverter<T>) ClassUtil.forName(converterName, BytesConverter.class).newInstance(); this.nValues = BytesUtil.readUnsigned(trieBytes, headSize + sizeChildOffset, sizeNoValuesBeneath); this.sizeOfId = BytesUtil.sizeForValue(baseId + nValues + 1); // note baseId could raise 1 byte in ID space, +1 to reserve all 0xFF for NULL case this.childOffsetMask = ~((BIT_IS_LAST_CHILD | BIT_IS_END_OF_VALUE) << ((sizeChildOffset - 1) * 8)); this.firstByteOffset = sizeChildOffset + sizeNoValuesBeneath + 1; // the offset from begin of node to its first value byte } catch (Exception e) { if (e instanceof RuntimeException) throw (RuntimeException) e; else throw new RuntimeException(e); } if (enableCache) { valueToIdCache = new SoftReference<HashMap>(new HashMap()); idToValueCache = new SoftReference<Object[]>(new Object[nValues]); } } @Override public int getMinId() { return baseId; } @Override public int getMaxId() { return baseId + nValues - 1; } @Override public int getSizeOfId() { return sizeOfId; } @Override public int getSizeOfValue() { return maxValueLength; } @Override final protected int getIdFromValueImpl(T value, int roundingFlag) { if (enableCache && roundingFlag == 0) { HashMap cache = valueToIdCache.get(); // SoftReference to skip cache // gracefully when short of // memory if (cache != null) { Integer id = null; id = (Integer) cache.get(value); if (id != null) return id.intValue(); byte[] valueBytes = bytesConvert.convertToBytes(value); id = getIdFromValueBytes(valueBytes, 0, valueBytes.length, roundingFlag); cache.put(value, id); return id; } } byte[] valueBytes = bytesConvert.convertToBytes(value); return getIdFromValueBytes(valueBytes, 0, valueBytes.length, roundingFlag); } @Override protected int getIdFromValueBytesImpl(byte[] value, int offset, int len, int roundingFlag) { int seq = lookupSeqNoFromValue(headSize, value, offset, offset + len, roundingFlag); int id = calcIdFromSeqNo(seq); if (id < 0) throw new IllegalArgumentException("Not a valid value: " + bytesConvert.convertFromBytes(value, offset, len)); return id; } /** * returns a code point from [0, nValues), preserving order of value * * @param n * -- the offset of current node * @param inp * -- input value bytes to lookup * @param o * -- offset in the input value bytes matched so far * @param inpEnd * -- end of input * @param roundingFlag * -- =0: return -1 if not found -- <0: return closest smaller if * not found, might be -1 -- >0: return closest bigger if not * found, might be nValues */ private int lookupSeqNoFromValue(int n, byte[] inp, int o, int inpEnd, int roundingFlag) { if (inp.length == 0) // special 'empty' value return checkFlag(headSize, BIT_IS_END_OF_VALUE) ? 0 : roundSeqNo(roundingFlag, -1, -1, 0); int seq = 0; // the sequence no under track while (true) { // match the current node, note [0] of node's value has been matched // when this node is selected by its parent int p = n + firstByteOffset; // start of node's value int end = p + BytesUtil.readUnsigned(trieBytes, p - 1, 1); // end of // node's // value for (p++; p < end && o < inpEnd; p++, o++) { // note matching start // from [1] if (trieBytes[p] != inp[o]) { int comp = BytesUtil.compareByteUnsigned(trieBytes[p], inp[o]); if (comp < 0) { seq += BytesUtil.readUnsigned(trieBytes, n + sizeChildOffset, sizeNoValuesBeneath); } return roundSeqNo(roundingFlag, seq - 1, -1, seq); // mismatch } } // node completely matched, is input all consumed? boolean isEndOfValue = checkFlag(n, BIT_IS_END_OF_VALUE); if (o == inpEnd) { return p == end && isEndOfValue ? seq : roundSeqNo(roundingFlag, seq - 1, -1, seq); // input // all // matched } if (isEndOfValue) seq++; // find a child to continue int c = headSize + (BytesUtil.readUnsigned(trieBytes, n, sizeChildOffset) & childOffsetMask); if (c == headSize) // has no children return roundSeqNo(roundingFlag, seq - 1, -1, seq); // input only // partially // matched byte inpByte = inp[o]; int comp; while (true) { p = c + firstByteOffset; comp = BytesUtil.compareByteUnsigned(trieBytes[p], inpByte); if (comp == 0) { // continue in the matching child, reset n and // loop again n = c; o++; break; } else if (comp < 0) { // try next child seq += BytesUtil.readUnsigned(trieBytes, c + sizeChildOffset, sizeNoValuesBeneath); if (checkFlag(c, BIT_IS_LAST_CHILD)) return roundSeqNo(roundingFlag, seq - 1, -1, seq); // no // child // can // match // the // next // byte // of // input c = p + BytesUtil.readUnsigned(trieBytes, p - 1, 1); } else { // children are ordered by their first value byte return roundSeqNo(roundingFlag, seq - 1, -1, seq); // no // child // can // match // the // next // byte // of // input } } } } private int roundSeqNo(int roundingFlag, int i, int j, int k) { if (roundingFlag == 0) return j; else if (roundingFlag < 0) return i; else return k; } @Override final protected T getValueFromIdImpl(int id) { if (enableCache) { Object[] cache = idToValueCache.get(); // SoftReference to skip // cache gracefully when // short of memory if (cache != null) { int seq = calcSeqNoFromId(id); if (seq < 0 || seq >= nValues) throw new IllegalArgumentException("Not a valid ID: " + id); if (cache[seq] != null) return (T) cache[seq]; byte[] value = new byte[getSizeOfValue()]; int length = getValueBytesFromId(id, value, 0); T result = bytesConvert.convertFromBytes(value, 0, length); cache[seq] = result; return result; } } byte[] value = new byte[getSizeOfValue()]; int length = getValueBytesFromId(id, value, 0); return bytesConvert.convertFromBytes(value, 0, length); } @Override protected int getValueBytesFromIdImpl(int id, byte[] returnValue, int offset) { if (id < baseId || id >= baseId + nValues) throw new IllegalArgumentException("Not a valid ID: " + id); int seq = calcSeqNoFromId(id); return lookupValueFromSeqNo(headSize, seq, returnValue, offset); } /** * returns a code point from [0, nValues), preserving order of value, or -1 * if not found * * @param n * -- the offset of current node * @param seq * -- the code point under track * @param returnValue * -- where return value is written to */ private int lookupValueFromSeqNo(int n, int seq, byte[] returnValue, int offset) { int o = offset; while (true) { // write current node value int p = n + firstByteOffset; int len = BytesUtil.readUnsigned(trieBytes, p - 1, 1); System.arraycopy(trieBytes, p, returnValue, o, len); o += len; // if the value is ended boolean isEndOfValue = checkFlag(n, BIT_IS_END_OF_VALUE); if (isEndOfValue) { seq--; if (seq < 0) return o - offset; } // find a child to continue int c = headSize + (BytesUtil.readUnsigned(trieBytes, n, sizeChildOffset) & childOffsetMask); if (c == headSize) // has no children return -1; // no child? corrupted dictionary! int nValuesBeneath; while (true) { nValuesBeneath = BytesUtil.readUnsigned(trieBytes, c + sizeChildOffset, sizeNoValuesBeneath); if (seq - nValuesBeneath < 0) { // value is under this child, // reset n and loop again n = c; break; } else { // go to next child seq -= nValuesBeneath; if (checkFlag(c, BIT_IS_LAST_CHILD)) return -1; // no more child? corrupted dictionary! p = c + firstByteOffset; c = p + BytesUtil.readUnsigned(trieBytes, p - 1, 1); } } } } private boolean checkFlag(int offset, int bit) { return (trieBytes[offset] & bit) > 0; } private int calcIdFromSeqNo(int seq) { if (seq < 0 || seq >= nValues) return -1; else return baseId + seq; } private int calcSeqNoFromId(int id) { return id - baseId; } @Override public void write(DataOutput out) throws IOException { out.write(trieBytes); } @Override public void readFields(DataInput in) throws IOException { byte[] headPartial = new byte[HEAD_MAGIC.length + Short.SIZE + Integer.SIZE]; in.readFully(headPartial); if (BytesUtil.compareBytes(HEAD_MAGIC, 0, headPartial, 0, HEAD_MAGIC.length) != 0) throw new IllegalArgumentException("Wrong file type (magic does not match)"); DataInputStream headIn = new DataInputStream( // new ByteArrayInputStream(headPartial, HEAD_SIZE_I, headPartial.length - HEAD_SIZE_I)); int headSize = headIn.readShort(); int bodyLen = headIn.readInt(); headIn.close(); byte[] all = new byte[headSize + bodyLen]; System.arraycopy(headPartial, 0, all, 0, headPartial.length); in.readFully(all, headPartial.length, all.length - headPartial.length); init(all); } @Override public void dump(PrintStream out) { out.println("Total " + nValues + " values"); for (int i = 0; i < nValues; i++) { int id = calcIdFromSeqNo(i); T value = getValueFromId(id); out.println(id + " (" + Integer.toHexString(id) + "): " + value); } } @Override public int hashCode() { return Arrays.hashCode(trieBytes); } @Override public boolean equals(Object o) { if ((o instanceof TrieDictionary) == false) { logger.info("Equals return false because o is not TrieDictionary"); return false; } TrieDictionary that = (TrieDictionary) o; return Arrays.equals(this.trieBytes, that.trieBytes); } public static void main(String[] args) throws Exception { TrieDictionaryBuilder<String> b = new TrieDictionaryBuilder<String>(new StringBytesConverter()); // b.addValue("part"); // b.print(); // b.addValue("part"); // b.print(); // b.addValue("par"); // b.print(); // b.addValue("partition"); // b.print(); // b.addValue("party"); // b.print(); // b.addValue("parties"); // b.print(); // b.addValue("paint"); // b.print(); b.addValue("-000000.41"); b.addValue("0000101.81"); b.addValue("6779331"); String t = "0000001.6131"; TrieDictionary<String> dict = b.build(0); System.out.println(dict.getIdFromValue(t, -1)); System.out.println(dict.getIdFromValue(t, 1)); } }