Java Code Examples for smile.data.parser.ArffParser#parse()

The following examples show how to use smile.data.parser.ArffParser#parse(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that the trees obtained from {@code getDecisionTreeFromDenseInput} and
 * {@code getDecisionTreeFromSparseInput} disagree on fewer than 10 rows of the parsed dataset.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testIrisSparseDenseEquals() throws IOException, ParseException, HiveException {
    String urlString =
            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
    DecisionTree.Node denseNode = getDecisionTreeFromDenseInput(urlString);
    DecisionTree.Node sparseNode = getDecisionTreeFromSparseInput(urlString);

    URL url = new URL(urlString);
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);

    // Count the rows on which the two trees give different predictions.
    int diff = 0;
    for (int i = 0; i < size; i++) {
        if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) {
            diff++;
        }
    }

    Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
}
 
Example 2
Source File: TreePredictUDFTest.java    From incubator-hivemall with Apache License 2.0 6 votes vote down vote up
/**
 * Runs leave-one-out cross validation over the parsed ARFF dataset and asserts that
 * {@code evalPredict(tree, row)} equals {@code tree.predict(row)} for each held-out row.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testIris() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);

        double[] heldOut = x[loocv.test[i]];
        Assert.assertEquals(tree.predict(heldOut), evalPredict(tree, heldOut));
    }
}
 
Example 3
Source File: DecisionTreeTest.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * For every leave-one-out fold, trains two trees with the same seed, one from
 * {@code matrix(trainx, true)} and one from {@code matrix(trainx, false)}, and asserts
 * that both give the same prediction for the held-out row and the same {@code toString()}.
 *
 * @param datasetUrl URL of the ARFF dataset to download
 * @param responseIndex index of the response attribute passed to the ARFF parser
 * @param numLeafs leaf-count argument passed to both {@code DecisionTree} constructors
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 */
private static void runAndCompareSparseAndDense(String datasetUrl, int responseIndex,
        int numLeafs) throws IOException, ParseException {
    URL url = new URL(datasetUrl);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    final AttributeDataset ds;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ds = arffParser.parse(is);
    }
    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
        // Both trees are seeded with the fold index so they draw from identical PRNG streams.
        DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));
        DecisionTree stree = new DecisionTree(attrs, matrix(trainx, false), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));

        double[] heldOut = x[loocv.test[i]];
        Assert.assertEquals(dtree.predict(heldOut), stree.predict(heldOut));
        Assert.assertEquals(dtree.toString(), stree.toString());
    }
}
 
Example 4
Source File: TreePredictUDFv1Test.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a regression tree on a random 3/4 split of the parsed ARFF dataset, prints the
 * RMSE on the remaining rows, and asserts that {@code evalPredict} stays within 1.0 of
 * {@code tree.predict} on every held-out row.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testCpu2() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    final AttributeDataset data;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    // First m permuted rows form the training set.
    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Remaining n - m permuted rows form the test set.
    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));

    for (double[] row : testx) {
        assertEquals(tree.predict(row), evalPredict(tree, row), 1.0);
    }
}
 
Example 5
Source File: TreePredictUDFv1Test.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Runs 10-fold cross validation over the parsed ARFF dataset and asserts that
 * {@code evalPredict} stays within 1.0 of {@code tree.predict} on every test row of each fold.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testCpu() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    final AttributeDataset data;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int k = 10;

    CrossValidation cv = new CrossValidation(n, k);
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(datax, cv.train[i]);
        double[] trainy = Math.slice(datay, cv.train[i]);
        double[][] testx = Math.slice(datax, cv.test[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
        RegressionTree tree = new RegressionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

        for (double[] row : testx) {
            assertEquals(tree.predict(row), evalPredict(tree, row), 1.0);
        }
    }
}
 
Example 6
Source File: GradientTreeBoostingClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Builds {@code "index:value"} feature rows from the parsed ARFF dataset and passes them to
 * {@code TestUtils.testGenericUDTFSerialization} for {@code GradientTreeBoostingClassifierUDTF}
 * configured with {@code "-trees 490"}.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    // Each row is {features as "index:value" strings, label}.
    final Object[][] rows = new Object[size][2];
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        final List<String> xi = new ArrayList<String>(x[0].length);
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        rows[i][0] = xi;
        rows[i][1] = y[i];
    }

    TestUtils.testGenericUDTFSerialization(GradientTreeBoostingClassifierUDTF.class,
        new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490")},
        rows);
}
 
Example 7
Source File: Test.java    From java_in_examples with Apache License 2.0 5 votes vote down vote up
/**
 * Parses the classpath resource {@code /smile/data/weka/weather.nominal.arff} with the
 * response attribute at index 4 and converts the result to feature and label arrays.
 *
 * @throws IOException if the resource cannot be read
 * @throws ParseException if the ARFF content cannot be parsed
 */
private void test() throws IOException, ParseException {
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset weather;
    // Close the classpath stream once parsing is done; try-with-resources skips close()
    // when getResourceAsStream returns null, so a missing resource still fails inside parse().
    try (java.io.InputStream in =
            this.getClass().getResourceAsStream("/smile/data/weka/weather.nominal.arff")) {
        weather = arffParser.parse(in);
    }
    double[][] x = weather.toArray(new double[weather.size()][]);
    int[] y = weather.toArray(new int[weather.size()]);
}
 
Example 8
Source File: TreePredictUDFTest.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a regression tree on a random 3/4 split of the parsed ARFF dataset, prints the
 * RMSE on the remaining rows, and asserts that {@code evalPredict} stays within 1.0 of
 * {@code tree.predict} on every held-out row.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testCpu2() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    final AttributeDataset data;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    // First m permuted rows form the training set.
    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Remaining n - m permuted rows form the test set.
    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));

    for (double[] row : testx) {
        Assert.assertEquals(tree.predict(row), evalPredict(tree, row), 1.0);
    }
}
 
Example 9
Source File: SmileTargetClassifierBuilder.java    From ache with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a classifier from the ARFF file {@code smile_input.arff} under {@code trainingPath}
 * and hands the result to {@code SmileUtil.writeSmileClassifier} with the path
 * {@code pageclassifier.model} under {@code outputPath}.
 *
 * @param trainingPath directory containing {@code smile_input.arff}
 * @param outputPath directory used to build the model file path
 * @param learner learning algorithm name; {@code null} selects {@code "SVM"}
 * @param responseIndex index of the response attribute passed to the ARFF parser
 * @param skipCrossValidation if {@code true} the model comes from {@code trainClassifierNoCV},
 *        otherwise from {@code trainModelCV}
 * @throws Exception if reading the training data, training, or writing the model fails
 */
public static void trainModel(String trainingPath, String outputPath, String learner,
        int responseIndex, boolean skipCrossValidation) throws Exception {

    if (learner == null) {
        learner = "SVM";
    }

    System.out.println("Learning algorithm: " + learner);
    String modelFilePath = Paths.get(outputPath, "pageclassifier.model").toString();

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    Path arffFilePath = Paths.get(trainingPath, "/smile_input.arff");
    // Log before opening so the path is visible even when the file is missing.
    System.out.println("Reading training data file from: " + arffFilePath);

    final AttributeDataset trainingData;
    // Scope the file stream so it is closed even when parse() throws.
    try (FileInputStream fis = new FileInputStream(arffFilePath.toFile())) {
        trainingData = arffParser.parse(fis);
    }
    double[][] x = trainingData.toArray(new double[trainingData.size()][]);
    int[] y = trainingData.toArray(new int[trainingData.size()]);

    final SoftClassifier<double[]> finalModel;
    if (skipCrossValidation) {
        System.out.println("Starting model training on whole dataset...");
        finalModel = trainClassifierNoCV(learner, x, y);
    } else {
        System.out.println("Starting cross-validation...");
        finalModel = trainModelCV(learner, x, y);
    }
    System.out.println("Writing model to file: " + modelFilePath);
    SmileUtil.writeSmileClassifier(modelFilePath, finalModel);
}
 
Example 10
Source File: TreePredictUDFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains a regression tree on a random 3/4 split of the parsed ARFF dataset, Base91-encodes
 * its serialized form, and passes it with the first held-out row to
 * {@code TestUtils.testGenericUDFSerialization} for {@code TreePredictUDF}.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    final AttributeDataset data;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    // First m permuted rows form the training set.
    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Only the held-out features are needed below; the held-out targets are never read.
    double[][] testx = new double[n - m][];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

    byte[] b = tree.serialize(true);
    byte[] encoded = Base91.encode(b);
    Text model = new Text(encoded);

    TestUtils.testGenericUDFSerialization(TreePredictUDF.class,
        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.writableStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)},
        new Object[] {"model_id#1", model, ArrayUtils.toList(testx[0])});
}
 
Example 11
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds every row of the parsed ARFF dataset to {@code RandomForestClassifierUDTF} as a list
 * of doubles and asserts that closing the UDTF invokes the collector 49 times, matching
 * the {@code "-trees 49"} option.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for all rows; it is cleared after each process() call.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
Example 12
Source File: DecisionTreeTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains a decision tree on a seeded 70% split of the dataset and, for each remaining row,
 * runs {@code tree.predict} with a {@code PredictionHandler} that records the visited
 * branches and leaf. Asserts that a non-empty trace and branch map were produced per row.
 *
 * @param datasetUrl URL of the ARFF dataset to download
 * @param responseIndex index of the response attribute passed to the ARFF parser
 * @param numLeafs leaf-count argument passed to the {@code DecisionTree} constructor
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 */
private static void runTracePredict(String datasetUrl, int responseIndex, int numLeafs)
        throws IOException, ParseException {
    URL url = new URL(datasetUrl);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    final AttributeDataset ds;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ds = arffParser.parse(is);
    }
    final Attribute[] attrs = ds.attributes();
    final Attribute targetAttr = ds.response();

    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    // Fixed seed keeps the train/test split reproducible across runs.
    Random rnd = new Random(43L);
    int numTrain = (int) (x.length * 0.7);
    int[] index = ArrayUtils.shuffle(MathUtils.permutation(x.length), rnd);
    int[] cvTrain = Arrays.copyOf(index, numTrain);
    int[] cvTest = Arrays.copyOfRange(index, numTrain, index.length);

    double[][] trainx = Math.slice(x, cvTrain);
    int[] trainy = Math.slice(y, cvTrain);
    double[][] testx = Math.slice(x, cvTest);

    DecisionTree tree = new DecisionTree(SmileExtUtils.convertAttributeTypes(attrs),
        matrix(trainx, false), trainy, numLeafs, RandomNumberGeneratorFactory.createPRNG(43L));

    final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
    final StringBuilder buf = new StringBuilder();
    for (int i = 0; i < testx.length; i++) {
        final DenseVector test = new DenseVector(testx[i]);
        tree.predict(test, new PredictionHandler() {

            @Override
            public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                    double splitValue) {
                buf.append(attrs[splitFeatureIndex].name);
                buf.append(" [" + splitFeature + "] ");
                buf.append(op);
                buf.append(' ');
                buf.append(splitValue);
                buf.append('\n');

                map.put(attrs[splitFeatureIndex].name + " [" + splitFeature + "] " + op,
                    splitValue);
            }

            @Override
            public void visitLeaf(int output, double[] posteriori) {
                buf.append(targetAttr.toString(output));
            }
        });

        Assert.assertTrue(buf.length() > 0);
        Assert.assertFalse(map.isEmpty());

        // Reset the shared trace buffers before the next row.
        StringUtils.clear(buf);
        map.clear();
    }

}
 
Example 13
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains {@code RandomForestClassifierUDTF} with {@code "-trees 1 -seed 71"} on
 * {@code "index:value"} string features built from the ARFF dataset, then decodes the model
 * text captured from the collector into a {@code DecisionTree.Node}.
 *
 * @param urlString URL of the ARFF dataset to download
 * @return the tree deserialized from the last forwarded model text
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException if the UDTF fails
 */
private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for all rows; it is cleared after each process() call.
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        final double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Single-element array lets the anonymous collector hand the model back to this method.
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    return DecisionTree.deserialize(b, b.length, true);
}
 
Example 14
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains {@code RandomForestClassifierUDTF} with {@code "-trees 1 -seed 71"} on double-valued
 * feature lists built from the ARFF dataset, then decodes the model text captured from the
 * collector into a {@code DecisionTree.Node}.
 *
 * @param urlString URL of the ARFF dataset to download
 * @return the tree deserialized from the last forwarded model text
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException if the UDTF fails
 */
private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for all rows; it is cleared after each process() call.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Single-element array lets the anonymous collector hand the model back to this method.
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    return DecisionTree.deserialize(b, b.length, true);
}
 
Example 15
Source File: GradientTreeBoostingClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds every row of the parsed ARFF dataset to {@code GradientTreeBoostingClassifierUDTF} as
 * a list of doubles and asserts that closing the UDTF invokes the collector 490 times.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for all rows; it is cleared after each process() call.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
 
Example 16
Source File: GradientTreeBoostingClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds every row of the parsed ARFF dataset to {@code GradientTreeBoostingClassifierUDTF} as
 * {@code "index:value"} strings and asserts that closing the UDTF invokes the collector
 * 490 times.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for all rows; it is cleared after each process() call.
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
 
Example 17
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds every row of the parsed ARFF dataset to {@code RandomForestClassifierUDTF} as
 * {@code "index:value"} strings and asserts that closing the UDTF invokes the collector
 * 49 times.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for all rows; it is cleared after each process() call.
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
Example 18
Source File: TreePredictUDFv1Test.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains a regression tree on a random 3/4 split of the parsed ARFF dataset, generates its
 * op-code script via {@code predictOpCodegen}, and passes it with the first held-out row to
 * {@code TestUtils.testGenericUDFSerialization} for {@code TreePredictUDFv1}.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    final AttributeDataset data;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    // First m permuted rows form the training set.
    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Only the held-out features are needed below; the held-out targets are never read.
    double[][] testx = new double[n - m][];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    String opScript = tree.predictOpCodegen(StackMachine.SEP);

    TestUtils.testGenericUDFSerialization(TreePredictUDFv1.class,
        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)},
        new Object[] {"model_id#1", ModelType.opscode.getId(), opScript,
                ArrayUtils.toList(testx[0])});
}
 
Example 19
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds {@code RandomForestClassifierUDTF} rows whose features are all {@code null} and
 * expects a {@code HiveException} before the end of the method is reached.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test(expected = HiveException.class)
public void testIrisDenseAllNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Each row keeps its original width but every feature value is replaced with null.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, null);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    // Reaching this point means no HiveException was thrown above.
    Assert.fail("should not be called");
}
 
Example 20
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds {@code RandomForestClassifierUDTF} rows in which each feature is replaced with
 * {@code null} whenever a seeded random draw is at least 0.7, and asserts that closing the
 * UDTF still invokes the collector 49 times.
 *
 * <p>Requires network access: the ARFF file is downloaded from the hard-coded gist URL.
 */
@Test
public void testIrisDenseSomeNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // Scope the HTTP stream so it is closed even when parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Fixed seed keeps the pattern of nulled-out features reproducible across runs.
    final Random rand = new Random(43);
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            if (rand.nextDouble() >= 0.7) {
                xi.add(j, null);
            } else {
                xi.add(j, x[i][j]);
            }
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}