/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.topicmodel;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.CommandLineUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.struct.KeySortablePair;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
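/**
 * A UDAF that predicts the topic distribution of a document, given per-word
 * P(w|z) probabilities from a pre-trained pLSA model (see {@link PLSAUDTF}).
 *
 * A sketch of typical usage; the table and column names below are
 * illustrative, not part of this class:
 *
 * <pre>
 * SELECT
 *   t.docid,
 *   plsa_predict(t.word, t.value, m.label, m.prob, '-topics 10') AS probabilities
 * FROM test_docs t
 * JOIN plsa_model m ON (t.word = m.word)
 * GROUP BY t.docid;
 * </pre>
 */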
@Description(name = "plsa_predict",
        value = "_FUNC_(string word, float value, int label, float prob[, const string options])"
                + " - Returns a list which consists of <int label, float prob>")
public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {

    @Override
    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
        if (typeInfo.length != 4 && typeInfo.length != 5) {
            throw new UDFArgumentLengthException(
                "Expected argument length is 4 or 5 but given argument length was "
                        + typeInfo.length);
        }
        if (!HiveUtils.isStringTypeInfo(typeInfo[0])) {
            throw new UDFArgumentTypeException(0,
                "String type is expected for the first argument word: "
                        + typeInfo[0].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) {
            throw new UDFArgumentTypeException(1,
                "Number type is expected for the second argument value: "
                        + typeInfo[1].getTypeName());
        }
        if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) {
            throw new UDFArgumentTypeException(2,
                "Integer type is expected for the third argument label: "
                        + typeInfo[2].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
            throw new UDFArgumentTypeException(3,
                "Number type is expected for the fourth argument prob: "
                        + typeInfo[3].getTypeName());
        }
        if (typeInfo.length == 5) {
            if (!HiveUtils.isStringTypeInfo(typeInfo[4])) {
                throw new UDFArgumentTypeException(4,
                    "String type is expected for the fifth argument options: "
                            + typeInfo[4].getTypeName());
            }
        }
        return new Evaluator();
    }

    public static class Evaluator extends GenericUDAFEvaluator {

        // input OI
        private PrimitiveObjectInspector wordOI;
        private PrimitiveObjectInspector valueOI;
        private PrimitiveObjectInspector labelOI;
        private PrimitiveObjectInspector probOI;

        // Hyperparameters
        private int topics;
        private float alpha;
        private double delta;

        // merge OI
        private StructObjectInspector internalMergeOI;
        private StructField wcListField;
        private StructField probMapField;
        private StructField topicsOptionField;
        private StructField alphaOptionField;
        private StructField deltaOptionField;
        private PrimitiveObjectInspector wcListElemOI;
        private StandardListObjectInspector wcListOI;
        private StandardMapObjectInspector probMapOI;
        private PrimitiveObjectInspector probMapKeyOI;
        private StandardListObjectInspector probMapValueOI;
        private PrimitiveObjectInspector probMapValueElemOI;

        public Evaluator() {}

        protected Options getOptions() {
            Options opts = new Options();
            opts.addOption("k", "topics", true, "The number of topics [default: 10]");
            opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]");
            opts.addOption("delta", true,
                "Check convergence in the expectation step [default: 1E-5]");
            return opts;
        }

        @Nonnull
        protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException {
            String[] args = optionValue.split("\\s+");
            Options opts = getOptions();
            opts.addOption("help", false, "Show function help");
            CommandLine cl = CommandLineUtils.parseOptions(args, opts);
            if (cl.hasOption("help")) {
                Description funcDesc = getClass().getAnnotation(Description.class);
                final String cmdLineSyntax;
                if (funcDesc == null) {
                    cmdLineSyntax = getClass().getSimpleName();
                } else {
                    String funcName = funcDesc.name();
                    cmdLineSyntax = funcName == null ? getClass().getSimpleName()
                            : funcDesc.value().replace("_FUNC_", funcDesc.name());
                }
                StringWriter sw = new StringWriter();
                sw.write('\n');
                PrintWriter pw = new PrintWriter(sw);
                HelpFormatter formatter = new HelpFormatter();
                formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts,
                    HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true);
                pw.flush();
                String helpMsg = sw.toString();
                throw new UDFArgumentException(helpMsg);
            }
            return cl;
        }

        @Nullable
        protected CommandLine processOptions(ObjectInspector[] argOIs)
                throws UDFArgumentException {
            CommandLine cl = null;
            if (argOIs.length >= 5) {
                String rawArgs = HiveUtils.getConstString(argOIs[4]);
                cl = parseOptions(rawArgs);
                this.topics = Primitives.parseInt(cl.getOptionValue("topics"),
                    PLSAUDTF.DEFAULT_TOPICS);
                if (topics < 1) {
                    throw new UDFArgumentException(
                        "A positive integer MUST be set to an option `-topics`: " + topics);
                }
                this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"),
                    PLSAUDTF.DEFAULT_ALPHA);
                this.delta = Primitives.parseDouble(cl.getOptionValue("delta"),
                    PLSAUDTF.DEFAULT_DELTA);
            } else {
                this.topics = PLSAUDTF.DEFAULT_TOPICS;
                this.alpha = PLSAUDTF.DEFAULT_ALPHA;
                this.delta = PLSAUDTF.DEFAULT_DELTA;
            }
            return cl;
        }
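        // Note on evaluator modes (standard Hive UDAF semantics):
        //   PARTIAL1: iterate() on raw rows     -> terminatePartial()
        //   PARTIAL2: merge() partial results   -> terminatePartial()
        //   FINAL:    merge() partial results   -> terminate()
        //   COMPLETE: iterate() on raw rows     -> terminate()
        // init() selects its input/output ObjectInspectors accordingly below.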
        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length == 1 || parameters.length == 4 || parameters.length == 5);
            super.init(mode, parameters);

            // initialize input
            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
                processOptions(parameters);
                this.wordOI = HiveUtils.asStringOI(parameters[0]);
                this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]);
                this.labelOI = HiveUtils.asIntegerOI(parameters[2]);
                this.probOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
            } else {// from partial aggregation
                StructObjectInspector soi = (StructObjectInspector) parameters[0];
                this.internalMergeOI = soi;
                this.wcListField = soi.getStructFieldRef("wcList");
                this.probMapField = soi.getStructFieldRef("probMap");
                this.topicsOptionField = soi.getStructFieldRef("topics");
                this.alphaOptionField = soi.getStructFieldRef("alpha");
                this.deltaOptionField = soi.getStructFieldRef("delta");
                this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(wcListElemOI);
                this.probMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                // probMap values are lists of floats, matching internalMergeOI() below
                this.probMapValueElemOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
                this.probMapValueOI =
                        ObjectInspectorFactory.getStandardListObjectInspector(probMapValueElemOI);
                this.probMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(probMapKeyOI,
                    probMapValueOI);
            }

            // initialize output
            final ObjectInspector outputOI;
            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
                outputOI = internalMergeOI();
            } else {// terminate
                final ArrayList<String> fieldNames = new ArrayList<String>();
                final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
                fieldNames.add("label");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                fieldNames.add("probability");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
                outputOI = ObjectInspectorFactory.getStandardListObjectInspector(
                    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs));
            }
            return outputOI;
        }

        private static StructObjectInspector internalMergeOI() {
            ArrayList<String> fieldNames = new ArrayList<String>();
            ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

            fieldNames.add("wcList");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector));
            fieldNames.add("probMap");
            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
            fieldNames.add("topics");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
            fieldNames.add("alpha");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            fieldNames.add("delta");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);

            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }
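        // Illustrative shape of one partial-aggregation struct (values are made up):
        //   wcList:  ["apple:1.0", "orange:2.0"]   -- "word:value" entries from iterate()
        //   probMap: {"apple" -> [0.1, -1.0]}      -- per-topic P(w|z); -1.0 marks "unseen"
        //   topics: 2, alpha: 0.5, delta: 1.0E-5   -- options forwarded to the merge side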
        @SuppressWarnings("deprecation")
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            AggregationBuffer myAggr = new PLSAPredictAggregationBuffer();
            reset(myAggr);
            return myAggr;
        }

        @Override
        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg;
            myAggr.reset();
            myAggr.setOptions(topics, alpha, delta);
        }

        @Override
        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
                Object[] parameters) throws HiveException {
            PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg;
            if (parameters[0] == null || parameters[1] == null || parameters[2] == null
                    || parameters[3] == null) {
                return;
            }
            String word = PrimitiveObjectInspectorUtils.getString(parameters[0], wordOI);
            float value = PrimitiveObjectInspectorUtils.getFloat(parameters[1], valueOI);
            int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], labelOI);
            float prob = PrimitiveObjectInspectorUtils.getFloat(parameters[3], probOI);
            myAggr.iterate(word, value, label, prob);
        }

        @Override
        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg;
            if (myAggr.wcList.size() == 0) {
                return null;
            }

            Object[] partialResult = new Object[5];
            partialResult[0] = myAggr.wcList;
            partialResult[1] = myAggr.probMap;
            partialResult[2] = new IntWritable(myAggr.topics);
            partialResult[3] = new FloatWritable(myAggr.alpha);
            partialResult[4] = new DoubleWritable(myAggr.delta);
            return partialResult;
        }
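        // The partial struct may arrive in serialized (LazyBinary) form, so each field
        // is first unwrapped via HiveUtils.castLazyBinaryObject() and converted to
        // plain Java objects before being merged into the aggregation buffer.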
        @Override
        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }

            Object wcListObj = internalMergeOI.getStructFieldData(partial, wcListField);
            List<?> wcListRaw = wcListOI.getList(HiveUtils.castLazyBinaryObject(wcListObj));

            // fix list elements to Java String objects
            int wcListSize = wcListRaw.size();
            List<String> wcList = new ArrayList<String>();
            for (int i = 0; i < wcListSize; i++) {
                wcList.add(PrimitiveObjectInspectorUtils.getString(wcListRaw.get(i), wcListElemOI));
            }

            Object probMapObj = internalMergeOI.getStructFieldData(partial, probMapField);
            Map<?, ?> probMapRaw = probMapOI.getMap(HiveUtils.castLazyBinaryObject(probMapObj));

            Map<String, List<Float>> probMap = new HashMap<String, List<Float>>();
            for (Map.Entry<?, ?> e : probMapRaw.entrySet()) {
                // fix map keys to Java String objects
                String word = PrimitiveObjectInspectorUtils.getString(e.getKey(), probMapKeyOI);

                Object probMapValueObj = e.getValue();
                List<?> probMapValueRaw =
                        probMapValueOI.getList(HiveUtils.castLazyBinaryObject(probMapValueObj));

                // fix map values to lists of Java Float objects
                int probMapValueSize = probMapValueRaw.size();
                List<Float> prob_word = new ArrayList<Float>();
                for (int i = 0; i < probMapValueSize; i++) {
                    prob_word.add(HiveUtils.getFloat(probMapValueRaw.get(i), probMapValueElemOI));
                }

                probMap.put(word, prob_word);
            }

            // restore options from partial result
            Object topicsObj = internalMergeOI.getStructFieldData(partial, topicsOptionField);
            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicsObj);
            Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField);
            this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);
            Object deltaObj = internalMergeOI.getStructFieldData(partial, deltaOptionField);
            this.delta =
                    PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj);

            PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg;
            myAggr.setOptions(topics, alpha, delta);
            myAggr.merge(wcList, probMap);
        }

        @SuppressWarnings("unchecked")
        @Override
        public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg;
            final float[] topicDistr = myAggr.get();

            final KeySortablePair<Float, Integer>[] sorted =
                    new KeySortablePair[topicDistr.length];
            for (int i = 0; i < topicDistr.length; i++) {
                sorted[i] = new KeySortablePair<>(topicDistr[i], i);
            }
            Arrays.sort(sorted, Collections.reverseOrder());

            final List<Object[]> result = new ArrayList<Object[]>(sorted.length);
            for (KeySortablePair<Float, Integer> e : sorted) {
                Object[] struct = new Object[2];
                struct[0] = new IntWritable(e.getValue()); // label
                struct[1] = new FloatWritable(e.getKey()); // probability
                result.add(struct);
            }
            return result;
        }
    }

    public static class PLSAPredictAggregationBuffer
            extends GenericUDAFEvaluator.AbstractAggregationBuffer {

        private List<String> wcList;
        private Map<String, List<Float>> probMap;
        private int topics;
        private float alpha;
        private double delta;

        PLSAPredictAggregationBuffer() {
            super();
        }

        void setOptions(int topics, float alpha, double delta) {
            this.topics = topics;
            this.alpha = alpha;
            this.delta = delta;
        }

        void reset() {
            this.wcList = new ArrayList<String>();
            this.probMap = new HashMap<String, List<Float>>();
        }

        void iterate(@Nonnull final String word, final float value, final int label,
                final float prob) {
            wcList.add(word + ":" + value);

            // for an unforeseen word, initialize its probs w/ -1s
            List<Float> prob_word = probMap.get(word);
            if (prob_word == null) {
                prob_word = new ArrayList<Float>(Collections.nCopies(topics, -1.f));
                probMap.put(word, prob_word);
            }

            // set the given prob value
            prob_word.set(label, Float.valueOf(prob));
        }

        void merge(@Nonnull final List<String> o_wcList,
                @Nonnull final Map<String, List<Float>> o_probMap) {
            wcList.addAll(o_wcList);

            for (Map.Entry<String, List<Float>> e : o_probMap.entrySet()) {
                String o_word = e.getKey();
                List<Float> o_prob_word = e.getValue();

                final List<Float> prob_word = probMap.get(o_word);
                if (prob_word == null) {// for an unforeseen word
                    probMap.put(o_word, o_prob_word);
                } else {// for a partially observed word
                    for (int k = 0; k < topics; k++) {
                        final float prob_k = o_prob_word.get(k).floatValue();
                        if (prob_k != -1.f) {// not default value
                            prob_word.set(k, prob_k); // set the partial prob value
                        }
                    }
                    probMap.put(o_word, prob_word);
                }
            }
        }
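        // Rebuilds an IncrementalPLSAModel from the accumulated P(w|z) values
        // (skipping the -1 "unseen" defaults) and asks it for the topic distribution
        // of the aggregated document, e.g., [0.8, 0.2] when topics = 2.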
        float[] get() {
            final IncrementalPLSAModel model = new IncrementalPLSAModel(topics, alpha, delta);

            for (Map.Entry<String, List<Float>> e : probMap.entrySet()) {
                final String word = e.getKey();
                final List<Float> prob_word = e.getValue();
                for (int k = 0; k < topics; k++) {
                    final float prob_k = prob_word.get(k).floatValue();
                    if (prob_k != -1.f) {
                        model.setWordScore(word, k, prob_k);
                    }
                }
            }

            String[] wcArray = wcList.toArray(new String[wcList.size()]);
            return model.getTopicDistribution(wcArray);
        }
    }
}