/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.evaluation;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.util.Arrays;
import java.util.List;

/**
 * Tests for {@link FMeasureUDAF}, driven through its {@link GenericUDAFEvaluator} in
 * {@code PARTIAL1} mode with both scalar (binary) labels and list (multi-label) inputs.
 */
public class FMeasureUDAFTest {

    /** Ground-truth labels shared by the binary multi-sample tests. */
    private static final int[] BINARY_ACTUAL = {0, 1, 0, 0, 0, 1, 0, 0};
    /** Predicted labels paired index-by-index with {@link #BINARY_ACTUAL}. */
    private static final int[] BINARY_PREDICTED = {1, 0, 0, 1, 0, 1, 0, 1};

    /** Ground-truth label sets shared by the multi-label multi-sample tests. */
    private static final String[][] MULTI_LABEL_ACTUAL =
            {{"0", "2"}, {"0", "1"}, {"0"}, {"2"}, {"2", "0"}, {"0", "1"}, {"1", "2"}};
    /** Predicted label sets paired index-by-index with {@link #MULTI_LABEL_ACTUAL}. */
    private static final String[][] MULTI_LABEL_PREDICTED =
            {{"0", "1"}, {"0", "2"}, {}, {"2"}, {"2", "0"}, {"0", "1", "2"}, {"1"}};

    FMeasureUDAF fmeasure;
    GenericUDAFEvaluator evaluator;
    ObjectInspector[] inputOIs;
    FMeasureUDAF.FMeasureAggregationBuffer agg;

    /** Default fixture: list inputs with the option string {@code "-beta 1."}. */
    @Before
    public void setUp() throws Exception {
        configure(longListOI(), longListOI(), "-beta 1.");
    }

    /** Rebuilds the fixture for list (multi-label) inputs with explicit options. */
    private void setUpWithArguments(double beta, String average) throws Exception {
        configure(longListOI(), longListOI(), optionString(beta, average));
    }

    /**
     * Rebuilds the fixture for scalar (binary) inputs; the object inspectors are chosen from the
     * runtime types of the sample {@code actual} / {@code predicted} values.
     *
     * @throws IllegalArgumentException if either sample is not an Integer, Boolean or String
     */
    private void binarySetUp(Object actual, Object predicted, double beta, String average)
            throws Exception {
        configure(javaPrimitiveOIFor(actual), javaPrimitiveOIFor(predicted),
            optionString(beta, average));
    }

    /**
     * Single place where the UDAF, its input inspectors, the evaluator and the aggregation buffer
     * are (re)created, so every set-up variant wires them identically.
     */
    private void configure(ObjectInspector actualOI, ObjectInspector predictedOI, String options)
            throws Exception {
        fmeasure = new FMeasureUDAF();
        inputOIs = new ObjectInspector[] {actualOI, predictedOI,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, options)};
        evaluator =
                fmeasure.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
        agg = (FMeasureUDAF.FMeasureAggregationBuffer) evaluator.getNewAggregationBuffer();
    }

    /** Formats the constant option argument handed to the UDAF. */
    private static String optionString(double beta, String average) {
        return "-beta " + beta + " -average " + average;
    }

    /** List inspector over writable longs, used for both multi-label arguments. */
    private static ObjectInspector longListOI() {
        return ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableLongObjectInspector);
    }

    /**
     * Maps a sample value to the matching Java primitive object inspector. Fails fast on an
     * unsupported type instead of leaving a {@code null} slot in {@link #inputOIs}.
     */
    private static ObjectInspector javaPrimitiveOIFor(Object sample) {
        final PrimitiveObjectInspector.PrimitiveCategory category;
        if (sample instanceof Integer) {
            category = PrimitiveObjectInspector.PrimitiveCategory.INT;
        } else if (sample instanceof Boolean) {
            category = PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN;
        } else if (sample instanceof String) {
            category = PrimitiveObjectInspector.PrimitiveCategory.STRING;
        } else {
            throw new IllegalArgumentException("Unsupported label type: "
                    + (sample == null ? "null" : sample.getClass().getName()));
        }
        return PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(category);
    }

    /** Initialises the evaluator in PARTIAL1 mode and clears the aggregation buffer. */
    private void beginAggregation() throws HiveException {
        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
        evaluator.reset(agg);
    }

    /** Feeds one (actual, predicted) row into the aggregation buffer. */
    private void feed(Object actual, Object predicted) throws HiveException {
        evaluator.iterate(agg, new Object[] {actual, predicted});
    }

    /** Configures the binary fixture and feeds every shared binary sample. */
    private void aggregateBinarySamples(double beta, String average) throws Exception {
        binarySetUp(BINARY_ACTUAL[0], BINARY_PREDICTED[0], beta, average);
        beginAggregation();
        for (int i = 0; i < BINARY_ACTUAL.length; i++) {
            feed(BINARY_ACTUAL[i], BINARY_PREDICTED[i]);
        }
    }

    /** Configures the multi-label fixture and feeds every shared multi-label sample. */
    private void aggregateMultiLabelSamples(double beta, String average) throws Exception {
        setUpWithArguments(beta, average);
        beginAggregation();
        for (int i = 0; i < MULTI_LABEL_ACTUAL.length; i++) {
            feed(MULTI_LABEL_ACTUAL[i], MULTI_LABEL_PREDICTED[i]);
        }
    }

    @Test
    public void testBinaryMultiSamplesAverageBinary() throws Exception {
        aggregateBinarySamples(1., "binary");

        // should equal to turi's result
        // https://turi.com/learn/userguide/evaluation/classification.html#fscores-f1-fbeta-
        Assert.assertEquals(0.3333d, agg.get(), 1e-4);
    }

    /** Macro averaging over binary inputs is expected to be rejected. */
    @Test(expected = HiveException.class)
    public void testBinaryMultiSamplesAverageMacro() throws Exception {
        aggregateBinarySamples(1., "macro");

        agg.get();
    }

    @Test
    public void testBinaryMultiSamples() throws Exception {
        aggregateBinarySamples(1., "micro");

        Assert.assertEquals(0.5d, agg.get(), 1e-4);
    }

    @Test
    public void testBinaryMultiSamplesBeta2() throws Exception {
        aggregateBinarySamples(2.0, "binary");

        Assert.assertEquals(0.4166d, agg.get(), 1e-4);
    }

    @Test
    public void testBinary() throws Exception {
        int actual = 1;
        int predicted = 1;

        binarySetUp(actual, predicted, 1.0, "micro");
        beginAggregation();
        feed(actual, predicted);

        Assert.assertEquals(1.d, agg.get(), 1e-4);
    }

    @Test
    public void testBinaryNegativeInput() throws Exception {
        int actual = 1;
        int predicted = -1;

        binarySetUp(actual, predicted, 1.0, "binary");
        beginAggregation();
        feed(actual, predicted);

        Assert.assertEquals(0.d, agg.get(), 1e-4);
    }

    @Test
    public void testBinaryBooleanInput() throws Exception {
        boolean actual = true;
        boolean predicted = false;

        binarySetUp(actual, predicted, 1.0d, "binary");
        beginAggregation();
        feed(actual, predicted);

        Assert.assertEquals(0.d, agg.get(), 1e-4);
    }

    @Test(expected = HiveException.class)
    public void testBinaryInvalidStringInput() throws Exception {
        String actual = "cat";
        int predicted = 1;

        binarySetUp(actual, predicted, 1.0, "micro");
        beginAggregation();
        feed(actual, predicted);

        agg.get();
    }

    @Test(expected = HiveException.class)
    public void testBinaryInvalidLargeIntInput() throws Exception {
        int actual = 1;
        int predicted = 3;

        binarySetUp(actual, predicted, 1.0, "micro");
        beginAggregation();
        feed(actual, predicted);

        agg.get();
    }

    @Test(expected = HiveException.class)
    public void testMultiLabelZeroBeta() throws Exception {
        List<Integer> actual = Arrays.asList(1, 3, 2, 6);
        List<Integer> predicted = Arrays.asList(1, 2, 4);

        setUpWithArguments(0., "micro");
        beginAggregation();
        feed(actual, predicted);

        // FMeasure for beta has zero value is not defined
        agg.get();
    }

    @Test(expected = HiveException.class)
    public void testMultiLabelNegativeBeta() throws Exception {
        List<Integer> actual = Arrays.asList(1, 3, 2, 6);
        List<Integer> predicted = Arrays.asList(1, 2, 4);

        setUpWithArguments(-1.0d, "micro");
        beginAggregation();
        feed(actual, predicted);

        // FMeasure for beta has negative value is not defined
        agg.get();
    }

    @Test
    public void testMultiLabelF1score() throws Exception {
        List<Integer> actual = Arrays.asList(1, 3, 2, 6);
        List<Integer> predicted = Arrays.asList(1, 2, 4);

        // "micro" without stray leading whitespace, consistent with the sibling tests
        setUpWithArguments(1.0, "micro");
        beginAggregation();
        feed(actual, predicted);

        // should equal to spark's micro f1 measure result
        // https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html#multilabel-classification
        Assert.assertEquals(0.5714285714285714, agg.get(), 1e-5);
    }

    @Test
    public void testMultiLabelMaxFMeasure() throws Exception {
        List<Integer> actual = Arrays.asList(1, 2, 3);
        List<Integer> predicted = Arrays.asList(1, 2, 3);

        setUpWithArguments(1.0, "micro");
        beginAggregation();
        feed(actual, predicted);

        Assert.assertEquals(1.d, agg.get(), 1e-5);
    }

    @Test
    public void testMultiLabelMinFMeasure() throws Exception {
        List<Integer> actual = Arrays.asList(0, 0, 0);
        List<Integer> predicted = Arrays.asList(1, 2, 3);

        setUpWithArguments(1.0, "micro");
        beginAggregation();
        feed(actual, predicted);

        Assert.assertEquals(0.d, agg.get(), 1e-5);
    }

    @Test
    public void testMultiLabelF1MultiSamples() throws Exception {
        aggregateMultiLabelSamples(1.0, "micro");

        // should equal to spark's micro f1 measure result
        // https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html#multilabel-classification
        Assert.assertEquals(0.6956d, agg.get(), 1e-4);
    }

    @Test
    public void testMultiLabelFmeasureMultiSamples() throws Exception {
        aggregateMultiLabelSamples(2.0, "micro");

        Assert.assertEquals(0.6779d, agg.get(), 1e-4);
    }

    /** Binary averaging over multi-label inputs is expected to be rejected. */
    @Test(expected = HiveException.class)
    public void testMultiLabelFmeasureBinary() throws Exception {
        aggregateMultiLabelSamples(1.0, "binary");

        agg.get();
    }
}