/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.xgboost.utils; import biz.k11i.xgboost.Predictor; import biz.k11i.xgboost.util.FVec; import hivemall.utils.io.FastByteArrayInputStream; import hivemall.utils.io.IOUtils; import hivemall.xgboost.XGBoostBatchPredictUDTF.LabeledPointWithRowId; import ml.dmlc.xgboost4j.LabeledPoint; import ml.dmlc.xgboost4j.java.Booster; import ml.dmlc.xgboost4j.java.DMatrix; import ml.dmlc.xgboost4j.java.XGBoost; import ml.dmlc.xgboost4j.java.XGBoostError; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Properties; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.io.Text; public final class XGBoostUtils { private XGBoostUtils() {} @Nonnull public static String getVersion() throws HiveException { Properties props = new Properties(); try (InputStream versionResourceFile = Thread.currentThread().getContextClassLoader().getResourceAsStream( "xgboost4j-version.properties")) { props.load(versionResourceFile); } catch (IOException e) { throw new HiveException("Failed to load xgboost4j-version.properties", e); } return props.getProperty("version", "<unknown>"); } @Nonnull public static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data) throws XGBoostError { final List<LabeledPoint> points = new ArrayList<>(data.size()); for (LabeledPointWithRowId d : data) { points.add(d); } return new DMatrix(points.iterator(), ""); } @Nonnull public static Booster createBooster(@Nonnull DMatrix matrix, @Nonnull Map<String, Object> params) throws NoSuchMethodException, XGBoostError, IllegalAccessException, InvocationTargetException, InstantiationException { Class<?>[] args = {Map.class, DMatrix[].class}; Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args); ctor.setAccessible(true); return ctor.newInstance(new Object[] {params, new DMatrix[] {matrix}}); } public static void close(@Nullable final DMatrix matrix) { if (matrix == null) { return; } try { matrix.dispose(); } catch (Throwable e) { ; } } public static void close(@Nullable final Booster booster) { if (booster == null) { return; } try { booster.dispose(); } catch (Throwable e) { ; } } @Nonnull public static Text serializeBooster(@Nonnull final Booster booster) throws HiveException { try { byte[] b = IOUtils.toCompressedText(booster.toByteArray()); return new Text(b); } catch (Throwable e) { throw new HiveException("Failed to serialize a booster", e); } } @Nonnull public static Booster deserializeBooster(@Nonnull final Text model) throws HiveException { try { byte[] b = IOUtils.fromCompressedText(model.getBytes(), model.getLength()); return XGBoost.loadModel(new FastByteArrayInputStream(b)); } catch (Throwable e) { throw new HiveException("Failed to deserialize a booster", e); } } @Nonnull public static Predictor loadPredictor(@Nonnull final Text model) throws HiveException { try { byte[] b = IOUtils.fromCompressedText(model.getBytes(), model.getLength()); return new Predictor(new FastByteArrayInputStream(b)); } catch (Throwable e) { throw new HiveException("Failed to create a predictor", e); } } @Nonnull public static FVec parseRowAsFVec(@Nonnull final String[] row, final int start, final int end) { final Map<Integer, Float> map = new HashMap<>((int) (row.length * 1.5)); for (int i = start; i < end; i++) { String f = row[i]; if (f == null) { continue; } String str = f.toString(); final int pos = str.indexOf(':'); if (pos < 1) { throw new IllegalArgumentException("Invalid feature format: " + str); } final int index; final float value; try { index = Integer.parseInt(str.substring(0, pos)); value = Float.parseFloat(str.substring(pos + 1)); } catch (NumberFormatException e) { throw new IllegalArgumentException("Failed to parse a feature value: " + str); } map.put(index, value); } return FVec.Transformer.fromMap(map); } }