/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.ftvec.hashing; import hivemall.HivemallConstants; import hivemall.UDFWithOptions; import hivemall.annotations.VisibleForTesting; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hashing.MurmurHash3; import hivemall.utils.lang.Primitives; import hivemall.utils.lang.StringUtils; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; //@formatter:off @Description(name = "feature_hashing", value = "_FUNC_(array<string> features [, const string options])" + " - returns a hashed feature vector in array<string>", extended = "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-libsvm');\n" + " [\"4063537:1.0\",\"4063537:1\",\"8459207:2.0\"]\n" + "\n" + "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10');\n" + " [\"7:1.0\",\"7\",\"1:2.0\"]\n" + "\n" + "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10 -libsvm');\n" + " [\"1:2.0\",\"7:1.0\",\"7:1\"]\n" + "") //@formatter:on @UDFType(deterministic = true, stateful = false) public final class FeatureHashingUDF extends UDFWithOptions { private static final IndexComparator indexCmp = new IndexComparator(); @Nullable private ListObjectInspector _listOI; private boolean _libsvmFormat = false; private int _numFeatures = MurmurHash3.DEFAULT_NUM_FEATURES; @Nullable private transient List<String> _returnObj; public FeatureHashingUDF() {} @Override protected Options getOptions() { Options opts = new Options(); opts.addOption("libsvm", false, "Returns in libsvm format (<index>:<value>)* sorted by index ascending order"); opts.addOption("features", "num_features", true, "The number of features [default: 16777217 (2^24)]"); return opts; } @Override protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { CommandLine cl = parseOptions(optionValue); this._libsvmFormat = cl.hasOption("libsvm"); this._numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), _numFeatures); return cl; } @Override public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length != 1 && argOIs.length != 2) { showHelp("The feature_hashing function takes 1 or 2 arguments: " + argOIs.length); } ObjectInspector argOI0 = argOIs[0]; this._listOI = HiveUtils.isListOI(argOI0) ? (ListObjectInspector) argOI0 : null; if (argOIs.length == 2) { String opts = HiveUtils.getConstString(argOIs[1]); processOptions(opts); } if (_listOI == null) { return PrimitiveObjectInspectorFactory.javaStringObjectInspector; } else { return ObjectInspectorFactory.getStandardListObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector); } } @Override public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException { final Object arg0 = arguments[0].get(); if (arg0 == null) { return null; } if (_listOI == null) { return evaluateScalar(arg0); } else { return evaluateList(arg0); } } @Nonnull private String evaluateScalar(@Nonnull final Object arg0) { String fv = arg0.toString(); return featureHashing(fv, _numFeatures, _libsvmFormat); } @Nonnull private List<String> evaluateList(@Nonnull final Object arg0) throws HiveException { final int len = _listOI.getListLength(arg0); List<String> list = _returnObj; if (list == null) { list = new ArrayList<String>(len); this._returnObj = list; } else { list.clear(); } final int numFeatures = _numFeatures; for (int i = 0; i < len; i++) { Object obj = _listOI.getListElement(arg0, i); if (obj == null) { continue; } String fv = featureHashing(obj.toString(), numFeatures, _libsvmFormat); list.add(fv); } if (_libsvmFormat) { try { Collections.sort(list, indexCmp); } catch (NumberFormatException e) { throw new HiveException(e); } } return list; } @VisibleForTesting @Nonnull static String featureHashing(@Nonnull final String fv, final int numFeatures) { return featureHashing(fv, numFeatures, false); } @Nonnull static String featureHashing(@Nonnull final String fv, final int numFeatures, final boolean libsvmFormat) { final int headPos = fv.indexOf(':'); if (headPos == -1) { if (fv.equals(HivemallConstants.BIAS_CLAUSE)) { return fv; } final int h = mhash(fv, numFeatures); if (libsvmFormat) { return h + ":1"; } else { return String.valueOf(h); } } else { final int tailPos = fv.lastIndexOf(':'); if (headPos == tailPos) { String f = fv.substring(0, headPos); String tail = fv.substring(headPos); if (f.equals(HivemallConstants.BIAS_CLAUSE)) { String v = fv.substring(headPos + 1); double d = Double.parseDouble(v); if (d == 1.d) { return fv; } } int h = mhash(f, numFeatures); return h + tail; } else { String field = fv.substring(0, headPos + 1); String f = fv.substring(headPos + 1, tailPos); int h = mhash(f, numFeatures); String v = fv.substring(tailPos); return field + h + v; } } } static int mhash(@Nonnull final String word, final int numFeatures) { int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % numFeatures; if (r < 0) { r += numFeatures; } return r + 1; } @Override public String getDisplayString(String[] children) { return "feature_hashing(" + StringUtils.join(children, ',') + ')'; } private static final class IndexComparator implements Comparator<String>, Serializable { private static final long serialVersionUID = -260142385860586255L; @Override public int compare(@Nonnull final String lhs, @Nonnull final String rhs) { int l = getIndex(lhs); int r = getIndex(rhs); return Integer.compare(l, r); } private static int getIndex(@Nonnull final String fv) { final int headPos = fv.indexOf(':'); final int tailPos = fv.lastIndexOf(':'); final String f; if (headPos == tailPos) { f = fv.substring(0, headPos); } else { f = fv.substring(headPos + 1, tailPos); } return Integer.parseInt(f); } } }