/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.evaluation; import hivemall.UDAFEvaluatorWithOptions; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.Primitives; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import javax.annotation.Nonnull; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; @Description(name = "fmeasure", value = "_FUNC_(array|int|boolean actual, array|int| boolean predicted [, const string options])" + " - Return a F-measure (f1score is the special with beta=1.0)") public final class FMeasureUDAF extends AbstractGenericUDAFResolver { @Override public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { if (typeInfo.length != 2 && typeInfo.length != 3) { throw new UDFArgumentTypeException(typeInfo.length - 1, "_FUNC_ takes two or three arguments"); } boolean isArg1ListOrIntOrBoolean = HiveUtils.isListTypeInfo(typeInfo[0]) || HiveUtils.isIntegerTypeInfo(typeInfo[0]) || HiveUtils.isBooleanTypeInfo(typeInfo[0]); if (!isArg1ListOrIntOrBoolean) { throw new UDFArgumentTypeException(0, "The first argument `array/int/boolean actual` is invalid form: " + typeInfo[0]); } boolean isArg2ListOrIntOrBoolean = HiveUtils.isListTypeInfo(typeInfo[1]) || HiveUtils.isIntegerTypeInfo(typeInfo[1]) || HiveUtils.isBooleanTypeInfo(typeInfo[1]); if (!isArg2ListOrIntOrBoolean) { throw new UDFArgumentTypeException(1, "The second argument `array/int/boolean predicted` is invalid form: " + typeInfo[1]); } if (!typeInfo[0].equals(typeInfo[1])) { throw new UDFArgumentTypeException(1, "The first argument `actual`'s type is " + typeInfo[0] + ", but the second argument `predicted`'s type is not match: " + typeInfo[1]); } return new Evaluator(); } public static class Evaluator extends UDAFEvaluatorWithOptions { private ObjectInspector actualOI; private ObjectInspector predictedOI; private StructObjectInspector internalMergeOI; private StructField tpField; private StructField totalActualField; private StructField totalPredictedField; private StructField betaOptionField; private StructField averageOptionFiled; private double beta; private String average; public Evaluator() {} @Override protected Options getOptions() { Options opts = new Options(); opts.addOption("beta", true, "The weight of precision [default: 1.]"); opts.addOption("average", true, "The way of average calculation [default: micro]"); return opts; } @Override protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { CommandLine cl = null; double beta = 1.0d; String average = "micro"; if (argOIs.length >= 3) { String rawArgs = HiveUtils.getConstString(argOIs[2]); cl = parseOptions(rawArgs); beta = Primitives.parseDouble(cl.getOptionValue("beta"), beta); if (beta <= 0.d) { throw new UDFArgumentException( "The third argument `double beta` must be greater than 0.0: " + beta); } average = cl.getOptionValue("average", average); if (average.equals("macro")) { throw new UDFArgumentException("\"-average macro\" is not supported"); } if (!(average.equals("binary") || average.equals("micro"))) { throw new UDFArgumentException( "The third argument `String average` must be one of the {binary, micro, macro}: " + average); } } this.beta = beta; this.average = average; return cl; } @Override public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { assert (parameters.length == 2 || parameters.length == 3) : parameters.length; super.init(mode, parameters); // initialize input if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data processOptions(parameters); this.actualOI = parameters[0]; this.predictedOI = parameters[1]; } else {// from partial aggregation StructObjectInspector soi = (StructObjectInspector) parameters[0]; this.internalMergeOI = soi; this.tpField = soi.getStructFieldRef("tp"); this.totalActualField = soi.getStructFieldRef("totalActual"); this.totalPredictedField = soi.getStructFieldRef("totalPredicted"); this.betaOptionField = soi.getStructFieldRef("beta"); this.averageOptionFiled = soi.getStructFieldRef("average"); } // initialize output final ObjectInspector outputOI; if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial outputOI = internalMergeOI(); } else {// terminate outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; } return outputOI; } @Nonnull private static StructObjectInspector internalMergeOI() { List<String> fieldNames = new ArrayList<>(); List<ObjectInspector> fieldOIs = new ArrayList<>(); fieldNames.add("tp"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); fieldNames.add("totalActual"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); fieldNames.add("totalPredicted"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); fieldNames.add("beta"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); fieldNames.add("average"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @Override public FMeasureAggregationBuffer getNewAggregationBuffer() throws HiveException { FMeasureAggregationBuffer myAggr = new FMeasureAggregationBuffer(); reset(myAggr); return myAggr; } @Override public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; myAggr.reset(); myAggr.setOptions(beta, average); } @Override public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, Object[] parameters) throws HiveException { FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; boolean isList = HiveUtils.isListOI(actualOI) && HiveUtils.isListOI(predictedOI); final List<?> actual; final List<?> predicted; if (isList) {// array case if ("binary".equals(average)) { throw new UDFArgumentException( "\"-average binary\" is not supported when `predict` is array"); } actual = ((ListObjectInspector) actualOI).getList(parameters[0]); predicted = ((ListObjectInspector) predictedOI).getList(parameters[1]); } else {//binary case if (HiveUtils.isBooleanOI(actualOI)) { // boolean case actual = Arrays.asList( asIntLabel(parameters[0], (BooleanObjectInspector) actualOI)); predicted = Arrays.asList( asIntLabel(parameters[1], (BooleanObjectInspector) predictedOI)); } else { // integer case final int actualLabel = asIntLabel(parameters[0], HiveUtils.asIntegerOI(actualOI)); if (actualLabel == 0 && "binary".equals(average)) { actual = Collections.emptyList(); } else { actual = Arrays.asList(actualLabel); } final int predictedLabel = asIntLabel(parameters[1], HiveUtils.asIntegerOI(predictedOI)); if (predictedLabel == 0 && "binary".equals(average)) { predicted = Collections.emptyList(); } else { predicted = Arrays.asList(predictedLabel); } } } myAggr.iterate(actual, predicted); } private static int asIntLabel(@Nonnull final Object o, @Nonnull final BooleanObjectInspector booleanOI) { if (booleanOI.get(o)) { return 1; } else { return 0; } } private static int asIntLabel(@Nonnull final Object o, @Nonnull final PrimitiveObjectInspector intOI) throws UDFArgumentException { final int value = PrimitiveObjectInspectorUtils.getInt(o, intOI); switch (value) { case 1: return 1; case 0: case -1: return 0; default: throw new UDFArgumentException("Int label must be 1, 0 or -1: " + value); } } @Override public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; Object[] partialResult = new Object[5]; partialResult[0] = new LongWritable(myAggr.tp); partialResult[1] = new LongWritable(myAggr.totalActual); partialResult[2] = new LongWritable(myAggr.totalPredicted); partialResult[3] = new DoubleWritable(myAggr.beta); partialResult[4] = myAggr.average; return partialResult; } @Override public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) throws HiveException { if (partial == null) { return; } Object tpObj = internalMergeOI.getStructFieldData(partial, tpField); Object totalActualObj = internalMergeOI.getStructFieldData(partial, totalActualField); Object totalPredictedObj = internalMergeOI.getStructFieldData(partial, totalPredictedField); Object betaObj = internalMergeOI.getStructFieldData(partial, betaOptionField); Object averageObj = internalMergeOI.getStructFieldData(partial, averageOptionFiled); long tp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj); long totalActual = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(totalActualObj); long totalPredicted = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get( totalPredictedObj); double beta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(betaObj); String average = PrimitiveObjectInspectorFactory.writableStringObjectInspector.getPrimitiveJavaObject( averageObj); FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; myAggr.merge(tp, totalActual, totalPredicted, beta, average); } @Override public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; double result = myAggr.get(); return new DoubleWritable(result); } } @AggregationType(estimable = true) public static class FMeasureAggregationBuffer extends AbstractAggregationBuffer { long tp; /** tp + fn */ long totalActual; /** tp + fp */ long totalPredicted; double beta; String average; public FMeasureAggregationBuffer() { super(); } @Override public int estimate() { JavaDataModel model = JavaDataModel.get(); return model.primitive2() * 4 + model.lengthFor(average); } void setOptions(double beta, String average) { this.beta = beta; this.average = average; } void reset() { this.tp = 0L; this.totalActual = 0L; this.totalPredicted = 0L; } void merge(final long o_tp, final long o_actual, final long o_predicted, final double beta, final String average) { tp += o_tp; totalActual += o_actual; totalPredicted += o_predicted; this.beta = beta; this.average = average; } double get() { final double squareBeta = beta * beta; final double divisor; final double numerator; if ("micro".equals(average)) { divisor = denom(tp, totalActual, totalPredicted, squareBeta); numerator = (1.d + squareBeta) * tp; } else { // binary double precision = precision(tp, totalPredicted); double recall = recall(tp, totalActual); divisor = squareBeta * precision + recall; numerator = (1.d + squareBeta) * precision * recall; } if (divisor > 0) { return (numerator / divisor); } else { return 0.d; } } private static double denom(final long tp, final long totalActual, final long totalPredicted, double squareBeta) { long lp = totalActual - tp; long pl = totalPredicted - tp; return squareBeta * (tp + lp) + tp + pl; } private static double precision(final long tp, final long totalPredicted) { return (totalPredicted == 0L) ? 0.d : tp / (double) totalPredicted; } private static double recall(final long tp, final long totalActual) { return (totalActual == 0L) ? 0.d : tp / (double) totalActual; } void iterate(@Nonnull final List<?> actual, @Nonnull final List<?> predicted) { final int numActual = actual.size(); final int numPredicted = predicted.size(); int countTp = 0; for (Object p : predicted) { if (actual.contains(p)) { countTp++; } } this.tp += countTp; this.totalActual += numActual; this.totalPredicted += numPredicted; } } }