/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.util.FVec;
import hivemall.TestBase;
import hivemall.TestUtils;
import hivemall.utils.lang.PrivilegedAccessor;
import hivemall.utils.lang.mutable.MutableObject;
import hivemall.utils.math.MathUtils;
import hivemall.xgboost.utils.XGBoostUtils;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.Collector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
import org.junit.experimental.theories.DataPoints;
import org.junit.experimental.theories.FromDataPoints;
import org.junit.experimental.theories.Theories;
import org.junit.experimental.theories.Theory;
import org.junit.runner.RunWith;

/**
 * Tests for {@link XGBoostTrainUDTF}: serialization of the UDTF, end-to-end training over the
 * {@code trial} data points, and validation of options and multiclass target values.
 */
@RunWith(Theories.class)
public class XGBoostTrainUDTFTest extends TestBase {

    /**
     * Runs the generic UDTF serialization check against {@link XGBoostTrainUDTF} using string
     * features, a double target and a constant {@code reg:linear} option string.
     */
    @Test
    public void testSerialization() throws HiveException {
        TestUtils.testGenericUDTFSerialization(XGBoostTrainUDTF.class,
            new ObjectInspector[] {
                    ObjectInspectorFactory.getStandardListObjectInspector(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
                    ObjectInspectorUtils.getConstantObjectInspector(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                        "-objective reg:linear")},
            new Object[][] {{Arrays.asList("1:-2", "2:-1"), 0.d}});
    }

    /**
     * Data points fed to {@link #testHyperParams(TestParameter)}; each builder pairs a train/test
     * dataset and an evaluation metric with one or more hyper-parameter strings.
     */
    @DataPoints("trial")
    public static final List<TestParameter> trial = TestParameter.merge(
        // 1. binary classification
        // mushroom dataset
        new TestParameter.Builder().trainDataset(new LibsvmDataset(
            "https://raw.githubusercontent.com/dmlc/xgboost/master/demo/data/agaricus.txt.train"))
                                   .testDatase(new LibsvmDataset(
                                       "https://raw.githubusercontent.com/dmlc/xgboost/master/demo/data/agaricus.txt.test"))
                                   .metric(new ClassificationError(0.1f))
                                   .hyperParams(new String[] {
                                           "-objective binary:logistic -iters 10",
                                           "-objective binary:logistic -iters 10 -num_early_stopping_rounds 3"}),
        // 2. multiclass classification
        // https://archive.ics.uci.edu/ml/machine-learning-databases/dermatology/dermatology.data
        new TestParameter.Builder().trainDataset(new DermatologyDataset(true, 0.7f))
                                   .testDatase(new DermatologyDataset(false, 0.7f))
                                   .metric(new MultiClassClassificationError(0.1f))
                                   .hyperParams(new String[] {
                                           "-objective multi:softmax -num_class 6 -max_depth 6 -eta 0.1 -num_round 5"}),
        new TestParameter.Builder().trainDataset(new DermatologyDataset(true, 0.7f))
                                   .testDatase(new DermatologyDataset(false, 0.7f))
                                   .metric(new MultiClassClassificationError(0.1f))
                                   .hyperParams(new String[] {
                                           "-objective multi:softprob -num_class 6 -max_depth 6 -eta 0.1 -num_round 5"}),
        // 3. regression
        // https://archive.ics.uci.edu/ml/datasets/Computer+Hardware
        // https://github.com/dmlc/xgboost/blob/master/demo/regression/machine.conf
        new TestParameter.Builder().trainDataset(new LibsvmDataset(
            "https://raw.githubusercontent.com/myui/ml_dataset/master/regr/computer_hardware/machine.txt.train"))
                                   .testDatase(new LibsvmDataset(
                                       "https://raw.githubusercontent.com/myui/ml_dataset/master/regr/computer_hardware/machine.txt.test"))
                                   .metric(new MAE(40f))
                                   .hyperParams(new String[] {
                                           "-booster gbtree -objective reg:linear -eta 1.0 -gamma 1.0 -min_child_weight 1.0 -max_depth 3 -num_round 5"}));

    /**
     * Trains a model through the UDTF for one {@code trial} data point and verifies the forwarded
     * model three ways:
     * <ol>
     * <li>the deserialized {@link Booster} reproduces, row by row, the predictions captured from
     * the in-memory booster in {@code onFinishTraining};</li>
     * <li>the {@link Predictor} loaded from the same model reports the booster and objective names
     * configured on the UDTF;</li>
     * <li>the {@link Predictor} output stays close to the xgboost4j output (at most two mismatching
     * rows for non-regression objectives) and satisfies the trial's evaluation metric.</li>
     * </ol>
     *
     * @param trial dataset, metric and hyper-parameter combination under test
     */
    @Theory
    public void testHyperParams(@FromDataPoints("trial") final TestParameter trial)
            throws Exception {
        final Dataset trainDataset = trial.trainDataset;
        final Dataset testDataset = trial.testDataset;
        final DMatrix testMatrix = testDataset.loadDatasetAsDMatrix();

        try {
            final float[] testLabels = testMatrix.getLabel();
            final EvalMetric metric = trial.evalMetric;

            // Filled in by onFinishTraining with the predictions of the freshly trained booster.
            final MutableObject<float[][]> expectedPredictData = new MutableObject<>();

            final XGBoostTrainUDTF udtf = new XGBoostTrainUDTF() {
                @Override
                protected void onFinishTraining(Booster booster) {
                    final float[][] result;
                    try {
                        result = booster.predict(testMatrix);
                    } catch (XGBoostError e) {
                        throw new RuntimeException(e);
                    }
                    expectedPredictData.set(result);
                }
            };

            // Sparse datasets pass features as strings, dense ones as floats.
            final ObjectInspector featureElemOI;
            if (trainDataset.isSparseDataset()) {
                featureElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
            } else {
                featureElemOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
            }
            udtf.initialize(new ObjectInspector[] {
                    ObjectInspectorFactory.getStandardListObjectInspector(featureElemOI),
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
                    ObjectInspectorUtils.getConstantObjectInspector(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                        trial.hyperParams)});

            for (Object[] row : trainDataset.loadDatasetAsListOfObjects()) {
                udtf.process(row);
            }

            udtf.setCollector(new Collector() {
                @Override
                public void collect(Object input) throws HiveException {
                    final float[][] expecteds = expectedPredictData.get();

                    Object[] forwardedObj = (Object[]) input;
                    String modelId = (String) forwardedObj[0];
                    Assert.assertNotNull(modelId);
                    Text modelStr = (Text) forwardedObj[1];

                    // (1) The serialized model must predict exactly like the in-memory booster.
                    Booster booster = XGBoostUtils.deserializeBooster(modelStr);
                    try {
                        float[][] actuals = booster.predict(testMatrix);
                        Assert.assertEquals(expecteds.length, actuals.length);
                        // Was `i > expecteds.length`, which skipped every row comparison.
                        for (int i = 0; i < expecteds.length; i++) {
                            Assert.assertArrayEquals(expecteds[i], actuals[i], 1e-5f);
                        }
                    } catch (XGBoostError e) {
                        throw new HiveException(e);
                    } finally {
                        XGBoostUtils.close(booster);
                    }

                    // (2) The pure-Java predictor must carry the configured booster/objective.
                    Predictor predictor = XGBoostUtils.loadPredictor(modelStr);
                    final String gbmName, objName;
                    try {
                        gbmName = (String) PrivilegedAccessor.getValue(predictor, "name_gbm");
                        objName = (String) PrivilegedAccessor.getValue(predictor, "name_obj");
                    } catch (Exception e) {
                        throw new HiveException(e);
                    }
                    Assert.assertEquals(udtf.params.get("booster"), gbmName);
                    Assert.assertEquals(udtf.params.get("objective"), objName);

                    final List<FVec> fvList;
                    try {
                        fvList = testDataset.loadDatasetAsListOfFVec();
                    } catch (Exception e) {
                        throw new HiveException(e);
                    }
                    Assert.assertEquals(expecteds.length, fvList.size());

                    // (3) Compare xgboost-predictor against xgboost4j and feed the metric.
                    int mismatches = 0;
                    for (int i = 0; i < expecteds.length; i++) {
                        float[] expected = expecteds[i];
                        FVec fv = fvList.get(i);
                        double[] actual = predictor.predict(fv);
                        Assert.assertEquals(expected.length, actual.length);

                        // Regression objectives are judged by the metric only.
                        if (!objName.startsWith("reg:")) {
                            for (int j = 0; j < expected.length; j++) {
                                if (!MathUtils.equals(expected[j], actual[j], 1e-5d)) {
                                    mismatches++;
                                    break; // count each row at most once
                                }
                            }
                        }
                        metric.next(actual, testLabels[i]);
                    }
                    Assert.assertTrue(
                        "Too many mismatches in prediction result between xgboost4j and xgboost-predictor: "
                                + mismatches,
                        mismatches <= 2);
                }
            });

            // close() triggers training and forwards the model to the collector above.
            udtf.close();

            metric.assertExpected();
        } finally {
            // Release the matrix even when an assertion or training step fails.
            testMatrix.dispose();
        }
    }

    /** Initialization must fail when no {@code -objective} option is given. */
    @Test(expected = UDFArgumentException.class)
    public void testNoObjective() throws HiveException {
        XGBoostTrainUDTF udtf = new XGBoostTrainUDTF();
        udtf.initialize(
            new ObjectInspector[] {
                    ObjectInspectorFactory.getStandardListObjectInspector(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                    PrimitiveObjectInspectorFactory.javaFloatObjectInspector,
                    ObjectInspectorUtils.getConstantObjectInspector(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                        "-num_class 4")});
    }

    //---------------------------------------------------
    // multiclass target value tests

    /** Target values 1 and 3 are accepted for {@code multi:softmax} with {@code -num_class 4}. */
    @Test
    public void testCheckTargetValueSucess() throws HiveException {
        XGBoostTrainUDTF udtf = new XGBoostTrainUDTF();
        udtf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaFloatObjectInspector,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                    "-objective multi:softmax -num_class 4")});
        udtf.processTargetValue(1.0f);
        udtf.processTargetValue(3f);
    }

    /**
     * {@code multi:softmax} without {@code -num_class} must raise
     * {@link UDFArgumentException} before {@code Assert.fail} is reached.
     */
    @Test(expected = UDFArgumentException.class)
    public void testCheckInvalidTargetValue1() throws HiveException {
        XGBoostTrainUDTF udtf = new XGBoostTrainUDTF();
        udtf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaFloatObjectInspector,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                    "-objective multi:softmax")});
        udtf.processTargetValue(1.1f);
        Assert.fail("-num_class option is missing");
    }

    /** A negative target value must be rejected for {@code -num_class 3}. */
    @Test(expected = UDFArgumentException.class)
    public void testCheckInvalidTargetValue2() throws HiveException {
        XGBoostTrainUDTF udtf = new XGBoostTrainUDTF();
        udtf.processOptions(new ObjectInspector[] {null, null,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                    "-objective multi:softmax -num_class 3")});
        udtf.processTargetValue(-2f);
        Assert.fail();
    }

    /** A target value equal to {@code num_class} (3 with {@code -num_class 3}) must be rejected. */
    @Test(expected = UDFArgumentException.class)
    public void testCheckInvalidTargetValue3() throws HiveException {
        XGBoostTrainUDTF udtf = new XGBoostTrainUDTF();
        udtf.processOptions(new ObjectInspector[] {null, null,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                    "-objective multi:softmax -num_class 3")});
        udtf.processTargetValue(3f);
        Assert.fail();
    }
}