Java Code Examples for smile.data.parser.ArffParser#setResponseIndex()

The following examples show how to use smile.data.parser.ArffParser#setResponseIndex(). You can vote up the examples you like or vote down the ones you don't, and follow the link above each example to its original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: TreePredictUDFv1Test.java — from incubator-hivemall (Apache License 2.0), 6 votes
/**
 * Test of learn method, of class DecisionTree.
 *
 * <p>Runs leave-one-out cross-validation over the iris dataset (column 4 as the response) and
 * checks, for every held-out row, that {@code DecisionTree#predict} agrees with
 * {@code evalPredict} on the same tree.
 */
@Test
public void testIris() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing is done (it used to be leaked).
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    // Attribute types do not depend on the fold, so compute them once.
    // NOTE(review): assumes DecisionTree does not modify the bitmap it is given.
    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        DecisionTree tree = new DecisionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
        assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]]));
    }
}
 
Example 2
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 6 votes
/**
 * Compares the trees returned by {@code getDecisionTreeFromDenseInput} and
 * {@code getDecisionTreeFromSparseInput} (presumably trained on dense and sparse encodings of
 * the same iris data) and asserts they disagree on fewer than 10 iris rows.
 */
@Test
public void testIrisSparseDenseEquals() throws IOException, ParseException, HiveException {
    String urlString =
            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
    DecisionTree.Node denseNode = getDecisionTreeFromDenseInput(urlString);
    DecisionTree.Node sparseNode = getDecisionTreeFromSparseInput(urlString);

    URL url = new URL(urlString);
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    // Close the downloaded stream once parsed instead of leaking it.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);

    int diff = 0;
    for (int i = 0; i < size; i++) {
        if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) {
            diff++;
        }
    }

    Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
}
 
Example 3
Source File: DecisionTreeTest.java — from incubator-hivemall (Apache License 2.0), 6 votes
/**
 * For every leave-one-out fold of the iris dataset, serializes the trained decision tree with
 * {@code serialize(false)}, deserializes it with {@code deserialize(b, b.length, false)}, and
 * checks that the restored node predicts the same label as the original tree.
 */
@Test
public void testIrisSerializedObj() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    // Fold-independent, so computed once rather than per iteration.
    // NOTE(review): assumes DecisionTree does not modify the bitmap it is given.
    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);

        byte[] b = tree.serialize(false);
        Node node = DecisionTree.deserialize(b, b.length, false);
        assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]]));
    }
}
 
Example 4
Source File: DecisionTreeTest.java — from incubator-hivemall (Apache License 2.0), 6 votes
/**
 * Trains a decision tree on a remote ARFF dataset and exports it through an
 * {@code Evaluator} configured with {@code OutputType.graphviz}.
 *
 * @param datasetUrl URL of the ARFF dataset to download
 * @param responseIndex column index passed to {@code ArffParser#setResponseIndex}
 * @param numLeafs leaf-count argument forwarded to the {@code DecisionTree} constructor
 * @param dense whether the training matrix is built dense (forwarded to {@code matrix})
 * @param featureNames feature names forwarded to {@code Evaluator#export}
 * @param classNames class names forwarded to {@code Evaluator#export}
 * @param outputName output name passed to the {@code Evaluator} constructor
 * @return the exported text
 * @throws IOException if the dataset cannot be downloaded
 * @throws HiveException if the export fails
 * @throws ParseException if the ARFF content cannot be parsed
 */
private static String graphvizOutput(String datasetUrl, int responseIndex, int numLeafs,
        boolean dense, String[] featureNames, String[] classNames, String outputName)
        throws IOException, HiveException, ParseException {
    URL url = new URL(datasetUrl);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    AttributeDataset ds;
    // Release the network stream once the dataset has been parsed.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ds = arffParser.parse(is);
    }
    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
    // Fixed PRNG seed (31) for the tree's randomness.
    DecisionTree tree = new DecisionTree(attrs, matrix(x, dense), y, numLeafs,
        RandomNumberGeneratorFactory.createPRNG(31));

    Text model = new Text(Base91.encode(tree.serialize(true)));

    Evaluator eval = new Evaluator(OutputType.graphviz, outputName, false);
    Text exported = eval.export(model, featureNames, classNames);

    return exported.toString();
}
 
Example 5
Source File: DecisionTreeTest.java — from incubator-hivemall (Apache License 2.0), 5 votes
/**
 * Runs leave-one-out cross-validation on a remote ARFF dataset. For each fold it asserts that
 * trees trained on the dense ({@code matrix(trainx, true)}) and the sparse
 * ({@code matrix(trainx, false)}) representation, with the same PRNG seed, give the same
 * prediction for the held-out row and the same {@code toString()}.
 *
 * @param datasetUrl URL of the ARFF dataset to download
 * @param responseIndex column index passed to {@code ArffParser#setResponseIndex}
 * @param numLeafs leaf-count argument forwarded to the {@code DecisionTree} constructor
 * @throws IOException if the dataset cannot be downloaded
 * @throws ParseException if the ARFF content cannot be parsed
 */
private static void runAndCompareSparseAndDense(String datasetUrl, int responseIndex,
        int numLeafs) throws IOException, ParseException {
    URL url = new URL(datasetUrl);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    AttributeDataset ds;
    // Release the network stream once the dataset has been parsed.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ds = arffParser.parse(is);
    }
    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    // Fold-independent, so computed once; both trees of every fold already shared it.
    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));
        DecisionTree stree = new DecisionTree(attrs, matrix(trainx, false), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));
        Assert.assertEquals(dtree.predict(x[loocv.test[i]]), stree.predict(x[loocv.test[i]]));
        Assert.assertEquals(dtree.toString(), stree.toString());
    }
}
 
Example 6
Source File: TreePredictUDFv1Test.java — from incubator-hivemall (Apache License 2.0), 5 votes
/**
 * Trains a regression tree on 3/4 of the cpu dataset (column 6 as the response). It then
 * checks that {@code predict} and {@code evalPredict} agree within 1.0 on the remaining quarter.
 *
 * <p>NOTE(review): the split comes from {@code Math.permutate(n)}; presumably unseeded, so the
 * train/test partition may differ between runs.
 */
@Test
public void testCpu2() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);
    AttributeDataset data;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));

    for (int i = m; i < n; i++) {
        assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0);
    }
}
 
Example 7
Source File: Test.java — from java_in_examples (Apache License 2.0), 5 votes
/**
 * Parses the bundled {@code /smile/data/weka/weather.nominal.arff} resource with column 4 as
 * the response, then converts it into feature and label arrays.
 *
 * @throws IOException if the resource is missing or cannot be read
 * @throws ParseException if the ARFF content is malformed
 */
private void test() throws IOException, ParseException {
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    String resource = "/smile/data/weka/weather.nominal.arff";
    AttributeDataset weather;
    try (java.io.InputStream in = this.getClass().getResourceAsStream(resource)) {
        // getResourceAsStream signals a missing resource with null rather than an exception.
        if (in == null) {
            throw new IOException("Resource not found: " + resource);
        }
        weather = arffParser.parse(in);
    }
    // The arrays are the end product of this example; nothing here consumes them further.
    double[][] x = weather.toArray(new double[weather.size()][]);
    int[] y = weather.toArray(new int[weather.size()]);
}
 
Example 8
Source File: TreePredictUDFTest.java — from incubator-hivemall (Apache License 2.0), 5 votes
/**
 * Trains a regression tree on 3/4 of the cpu dataset (column 6 as the response). It then
 * checks that {@code predict} and {@code evalPredict} agree within 1.0 on the remaining quarter.
 *
 * <p>NOTE(review): the split comes from {@code Math.permutate(n)}; presumably unseeded, so the
 * train/test partition may differ between runs.
 */
@Test
public void testCpu2() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);
    AttributeDataset data;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));

    for (int i = m; i < n; i++) {
        Assert.assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0);
    }
}
 
Example 9
Source File: DecisionTreeTest.java — from incubator-hivemall (Apache License 2.0), 5 votes
/**
 * For every leave-one-out fold of the iris dataset, checks two things:
 * {@code serialize(true)} yields fewer bytes than {@code serialize(false)}, and the tree
 * restored from the {@code true} form predicts the same label as the original.
 */
@Test
public void testIrisSerializeObjCompressed() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    // Fold-independent, so computed once rather than per iteration.
    // NOTE(review): assumes DecisionTree does not modify the bitmap it is given.
    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);

        byte[] b1 = tree.serialize(true);
        byte[] b2 = tree.serialize(false);
        Assert.assertTrue("b1.length = " + b1.length + ", b2.length = " + b2.length,
            b1.length < b2.length);
        Node node = DecisionTree.deserialize(b1, b1.length, true);
        assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]]));
    }
}
 
Example 10
Source File: DecisionTreeTest.java — from incubator-hivemall (Apache License 2.0), 5 votes
/**
 * Runs leave-one-out cross-validation of a decision tree on a remote ARFF dataset and counts
 * the misclassified held-out rows.
 *
 * @param datasetUrl URL of the ARFF dataset to download
 * @param responseIndex column index passed to {@code ArffParser#setResponseIndex}
 * @param numLeafs leaf-count argument forwarded to the {@code DecisionTree} constructor
 * @param dense whether training matrices are built dense (forwarded to {@code matrix})
 * @return number of folds whose held-out row was predicted incorrectly
 * @throws IOException if the dataset cannot be downloaded
 * @throws ParseException if the ARFF content cannot be parsed
 */
private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense)
        throws IOException, ParseException {
    URL url = new URL(datasetUrl);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    AttributeDataset ds;
    // Release the network stream once the dataset has been parsed.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ds = arffParser.parse(is);
    }
    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    // Fold-independent, so computed once rather than per iteration.
    // NOTE(review): assumes DecisionTree does not modify the bitmap it is given.
    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    int error = 0;
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));
        if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) {
            error++;
        }
    }

    debugPrint("Decision Tree error = " + error);
    return error;
}
 
Example 11
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 5 votes
/**
 * Builds iris rows as {@code ["0:v0", "1:v1", ...]} string feature lists plus an int label. It
 * passes them, with the option string {@code "-trees 49"}, to
 * {@code TestUtils.testGenericUDTFSerialization} for {@code RandomForestClassifierUDTF}.
 */
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    final Object[][] rows = new Object[size][2];
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        // A fresh list per row: each one is stored in rows[i] and must not be shared.
        final List<String> xi = new ArrayList<String>(row.length);
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        rows[i][0] = xi;
        rows[i][1] = y[i];
    }

    TestUtils.testGenericUDTFSerialization(RandomForestClassifierUDTF.class,
        new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49")},
        rows);
}
 
Example 12
Source File: SmileTargetClassifierBuilder.java — from ache (Apache License 2.0), 5 votes
/**
 * Trains a classifier on {@code <trainingPath>/smile_input.arff} and writes it to
 * {@code <outputPath>/pageclassifier.model}.
 *
 * @param trainingPath directory containing {@code smile_input.arff}
 * @param outputPath directory where {@code pageclassifier.model} is written
 * @param learner learning algorithm name; {@code null} selects {@code "SVM"}
 * @param responseIndex column index passed to {@code ArffParser#setResponseIndex}
 * @param skipCrossValidation if {@code true}, train once on the whole dataset via
 *     {@code trainClassifierNoCV}; otherwise use {@code trainModelCV}
 * @throws Exception if reading, parsing, training or writing the model fails
 */
public static void trainModel(String trainingPath, String outputPath, String learner,
        int responseIndex, boolean skipCrossValidation) throws Exception {

    if (learner == null) {
        learner = "SVM";
    }

    System.out.println("Learning algorithm: " + learner);
    String modelFilePath = Paths.get(outputPath, "pageclassifier.model").toString();

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    Path arffFilePath = Paths.get(trainingPath, "smile_input.arff");
    // Logged before opening so a missing file can be traced to the path that was tried.
    System.out.println("Reading training data from: " + arffFilePath.toString());

    AttributeDataset trainingData;
    // The input file used to stay open for the rest of the JVM's life; close it after parsing.
    try (FileInputStream fis = new FileInputStream(arffFilePath.toFile())) {
        trainingData = arffParser.parse(fis);
    }
    double[][] x = trainingData.toArray(new double[trainingData.size()][]);
    int[] y = trainingData.toArray(new int[trainingData.size()]);

    SoftClassifier<double[]> finalModel;
    if (skipCrossValidation) {
        System.out.println("Starting model training on whole dataset...");
        finalModel = trainClassifierNoCV(learner, x, y);
    } else {
        System.out.println("Starting cross-validation...");
        finalModel = trainModelCV(learner, x, y);
    }
    System.out.println("Writing model to file: " + modelFilePath);
    SmileUtil.writeSmileClassifier(modelFilePath, finalModel);
}
 
Example 13
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Feeds dense iris rows (a list of doubles plus an int label) to
 * {@code RandomForestClassifierUDTF} configured with {@code "-trees 49"}. It then checks that
 * {@code close()} forwards exactly 49 rows to the collector.
 */
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One buffer is reused for every row and cleared after process().
    // NOTE(review): relies on process() not retaining the list it is given.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
Example 14
Source File: TreePredictUDFv1Test.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Trains a regression tree on 3/4 of the cpu dataset (column 6 as the response) and generates
 * its opcode script with {@code predictOpCodegen}. It then passes the script and the first
 * held-out row to {@code TestUtils.testGenericUDFSerialization} for {@code TreePredictUDFv1}.
 */
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);
    AttributeDataset data;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Only held-out features are needed here; the held-out labels were never read.
    double[][] testx = new double[n - m][];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    String opScript = tree.predictOpCodegen(StackMachine.SEP);

    TestUtils.testGenericUDFSerialization(TreePredictUDFv1.class,
        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)},
        new Object[] {"model_id#1", ModelType.opscode.getId(), opScript,
                ArrayUtils.toList(testx[0])});
}
 
Example 15
Source File: GradientTreeBoostingClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Feeds dense iris rows (a list of doubles plus an int label) to
 * {@code GradientTreeBoostingClassifierUDTF} configured with {@code "-trees 490"}. It then
 * checks that {@code close()} forwards exactly 490 rows to the collector.
 */
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One buffer is reused for every row and cleared after process().
    // NOTE(review): relies on process() not retaining the list it is given.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
 
Example 16
Source File: GradientTreeBoostingClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Feeds iris rows as {@code "index:value"} string feature lists plus an int label to
 * {@code GradientTreeBoostingClassifierUDTF} configured with {@code "-trees 490"}. It then
 * checks that {@code close()} forwards exactly 490 rows to the collector.
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One buffer is reused for every row and cleared after process().
    // NOTE(review): relies on process() not retaining the list it is given.
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
 
Example 17
Source File: TreePredictUDFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Trains a regression tree on 3/4 of the cpu dataset (column 6 as the response) and
 * Base91-encodes {@code serialize(true)}. It then passes the model and the first held-out row
 * to {@code TestUtils.testGenericUDFSerialization} for {@code TreePredictUDF}.
 */
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);
    AttributeDataset data;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Only held-out features are needed here; the held-out labels were never read.
    double[][] testx = new double[n - m][];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

    byte[] b = tree.serialize(true);
    byte[] encoded = Base91.encode(b);
    Text model = new Text(encoded);

    TestUtils.testGenericUDFSerialization(TreePredictUDF.class,
        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.writableStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)},
        new Object[] {"model_id#1", model, ArrayUtils.toList(testx[0])});
}
 
Example 18
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Feeds iris rows as {@code "index:value"} string feature lists plus an int label to
 * {@code RandomForestClassifierUDTF} configured with {@code "-trees 49"}. It then checks that
 * {@code close()} forwards exactly 49 rows to the collector.
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One buffer is reused for every row and cleared after process().
    // NOTE(review): relies on process() not retaining the list it is given.
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
Example 19
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Feeds {@code RandomForestClassifierUDTF} rows whose features are all {@code null} and
 * expects a {@link HiveException} somewhere in {@code process()} / {@code close()}. The
 * trailing {@code Assert.fail} guards against the exception never being raised.
 */
@Test(expected = HiveException.class)
public void testIrisDenseAllNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Every feature slot is null; only the row width is taken from the real data.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, null);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.fail("should not be called");
}
 
Example 20
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Feeds dense iris rows to {@code RandomForestClassifierUDTF} ({@code "-trees 49"}). Each
 * feature is replaced by {@code null} when a {@code Random(43)} draw is {@code >= 0.7}. The
 * test checks that {@code close()} still forwards exactly 49 rows to the collector.
 */
@Test
public void testIrisDenseSomeNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    // Close the downloaded stream as soon as parsing finishes.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Fixed seed so the same features are nulled on every run.
    final Random rand = new Random(43);
    // One buffer is reused for every row and cleared after process().
    // NOTE(review): relies on process() not retaining the list it is given.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            if (rand.nextDouble() >= 0.7) {
                xi.add(j, null);
            } else {
                xi.add(j, x[i][j]);
            }
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}