Java Code Examples for smile.data.AttributeDataset#toArray()

The following examples show how to use smile.data.AttributeDataset#toArray(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TreePredictUDFv1Test.java    From incubator-hivemall with Apache License 2.0 6 votes vote down vote up
/**
 * Leave-one-out check on the Iris ARFF data: for every held-out row, the class
 * returned by {@code tree.predict} must equal the value returned by
 * {@code evalPredict} for the same tree and row.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the code under test
 */
@Test
public void testIris() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // The stream is opened here, so it is closed here: try-with-resources
    // releases the connection even when parsing throws.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);
        double[] heldOut = x[loocv.test[i]];

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
        assertEquals(tree.predict(heldOut), evalPredict(tree, heldOut));
    }
}
 
Example 2
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 6 votes vote down vote up
/**
 * Builds one tree from dense input and one from sparse input (via the sibling
 * helper methods), predicts every Iris row with both, and requires fewer than
 * 10 rows on which the two predictions differ.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the helper methods
 */
@Test
public void testIrisSparseDenseEquals() throws IOException, ParseException, HiveException {
    String urlString =
            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
    DecisionTree.Node denseNode = getDecisionTreeFromDenseInput(urlString);
    DecisionTree.Node sparseNode = getDecisionTreeFromSparseInput(urlString);

    URL url = new URL(urlString);
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // Close the download stream once parsing finishes (or fails) instead of leaking it.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);

    int diff = 0;
    for (int i = 0; i < size; i++) {
        if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) {
            diff++;
        }
    }

    Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
}
 
Example 3
Source File: DecisionTreeTest.java    From incubator-hivemall with Apache License 2.0 6 votes vote down vote up
/**
 * Leave-one-out check that a tree serialized with {@code serialize(false)} and
 * restored with {@code DecisionTree.deserialize} predicts the same class for the
 * held-out row as the original tree.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the code under test
 */
@Test
public void testIrisSerializedObj() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // try-with-resources guarantees the download stream is released.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);
        double[] heldOut = x[loocv.test[i]];

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);

        // Round-trip through the uncompressed serialized form.
        byte[] b = tree.serialize(false);
        Node node = DecisionTree.deserialize(b, b.length, false);
        assertEquals(tree.predict(heldOut), node.predict(heldOut));
    }
}
 
Example 4
Source File: TreePredictUDFTest.java    From incubator-hivemall with Apache License 2.0 6 votes vote down vote up
/**
 * Leave-one-out check on the Iris ARFF data: for every held-out row, the class
 * returned by {@code tree.predict} must equal the value returned by
 * {@code evalPredict} for the same tree and row.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the code under test
 */
@Test
public void testIris() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // Scope the stream to the parse call so it is always closed.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);
        double[] heldOut = x[loocv.test[i]];

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
        Assert.assertEquals(tree.predict(heldOut), evalPredict(tree, heldOut));
    }
}
 
Example 5
Source File: TreePredictUDFv1Test.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * 10-fold cross-validation on the CPU ARFF data: for each fold a
 * {@code RegressionTree} is trained and, for every test row,
 * {@code tree.predict} and {@code evalPredict} must agree within a delta of 1.0.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the code under test
 */
@Test
public void testCpu() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    // Close the download stream once parsing finishes (or fails) instead of leaking it.
    final AttributeDataset data;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int k = 10;

    CrossValidation cv = new CrossValidation(n, k);
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(datax, cv.train[i]);
        double[] trainy = Math.slice(datay, cv.train[i]);
        double[][] testx = Math.slice(datax, cv.test[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
        RegressionTree tree = new RegressionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

        for (int j = 0; j < testx.length; j++) {
            assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0);
        }
    }
}
 
Example 6
Source File: GradientTreeBoostingClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Converts every Iris row into a list of {@code "index:value"} strings paired
 * with its label and passes the rows to
 * {@code TestUtils.testGenericUDTFSerialization} for
 * {@code GradientTreeBoostingClassifierUDTF} configured with {@code "-trees 490"}.
 *
 * @throws HiveException propagated from the serialization helper
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 */
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // try-with-resources guarantees the download stream is released.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    final Object[][] rows = new Object[size][2];
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        // Presize from the row actually being converted.
        final List<String> xi = new ArrayList<String>(row.length);
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        rows[i][0] = xi;
        rows[i][1] = y[i];
    }

    TestUtils.testGenericUDTFSerialization(GradientTreeBoostingClassifierUDTF.class,
        new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490")},
        rows);
}
 
Example 7
Source File: SmileTargetClassifierBuilder.java    From ache with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a classifier from {@code <trainingPath>/smile_input.arff} and writes it
 * to {@code <outputPath>/pageclassifier.model}.
 *
 * @param trainingPath directory containing {@code smile_input.arff}
 * @param outputPath directory in which {@code pageclassifier.model} is written
 * @param learner name of the learning algorithm; {@code null} selects {@code "SVM"}
 * @param responseIndex index of the response attribute passed to the ARFF parser
 * @param skipCrossValidation if {@code true} the model is trained via
 *        {@code trainClassifierNoCV}; otherwise via {@code trainModelCV}
 * @throws Exception if reading, parsing, training or writing fails
 */
public static void trainModel(String trainingPath, String outputPath, String learner,
        int responseIndex, boolean skipCrossValidation) throws Exception {

    if (learner == null) {
        learner = "SVM";
    }

    System.out.println("Learning algorithm: " + learner);
    String modelFilePath = Paths.get(outputPath, "pageclassifier.model").toString();

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    Path arffFilePath = Paths.get(trainingPath, "/smile_input.arff");

    // The file stream is opened here, so it is closed here; previously the
    // FileInputStream was never closed and leaked a file handle.
    final AttributeDataset trainingData;
    try (FileInputStream fis = new FileInputStream(arffFilePath.toFile())) {
        // This method only reads the ARFF file; the previous message claimed it was being written.
        System.out.println("Reading training data from: " + arffFilePath.toString());
        trainingData = arffParser.parse(fis);
    }
    double[][] x = trainingData.toArray(new double[trainingData.size()][]);
    int[] y = trainingData.toArray(new int[trainingData.size()]);

    final SoftClassifier<double[]> finalModel;
    if (skipCrossValidation) {
        System.out.println("Starting model training on whole dataset...");
        finalModel = trainClassifierNoCV(learner, x, y);
    } else {
        System.out.println("Starting cross-validation...");
        finalModel = trainModelCV(learner, x, y);
    }
    System.out.println("Writing model to file: " + modelFilePath);
    SmileUtil.writeSmileClassifier(modelFilePath, finalModel);
}
 
Example 8
Source File: DecisionTreeTest.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Leave-one-out check of compressed serialization: the output of
 * {@code serialize(true)} must be shorter than that of {@code serialize(false)},
 * and the node restored from the compressed bytes must predict the same class
 * for the held-out row as the original tree.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the code under test
 */
@Test
public void testIrisSerializeObjCompressed() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // Scope the stream to the parse call so it is always closed.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);
        double[] heldOut = x[loocv.test[i]];

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);

        byte[] b1 = tree.serialize(true);
        byte[] b2 = tree.serialize(false);
        Assert.assertTrue("b1.length = " + b1.length + ", b2.length = " + b2.length,
            b1.length < b2.length);
        Node node = DecisionTree.deserialize(b1, b1.length, true);
        assertEquals(tree.predict(heldOut), node.predict(heldOut));
    }
}
 
Example 9
Source File: DecisionTreeTest.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Runs leave-one-out validation of {@code DecisionTree} on an ARFF dataset and
 * returns the number of held-out rows whose predicted class differs from the label.
 *
 * @param datasetUrl URL of the ARFF dataset
 * @param responseIndex index of the response attribute passed to the ARFF parser
 * @param numLeafs value passed to the {@code DecisionTree} constructor
 * @param dense flag forwarded to {@code matrix(trainx, dense)}
 * @return the count of misclassified held-out rows
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 */
private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense)
        throws IOException, ParseException {
    URL url = new URL(datasetUrl);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    // Close the download stream once parsing finishes (or fails) instead of leaking it.
    final AttributeDataset ds;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ds = arffParser.parse(is);
    }
    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    int error = 0;
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
        // The fold index doubles as the PRNG seed.
        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));
        int held = loocv.test[i];
        if (y[held] != tree.predict(x[held])) {
            error++;
        }
    }

    debugPrint("Decision Tree error = " + error);
    return error;
}
 
Example 10
Source File: TreePredictUDFv1Test.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Splits the CPU ARFF data 3/4 train and 1/4 test using a permutation from
 * {@code Math.permutate}, trains a {@code RegressionTree}, prints the RMSE via
 * {@code debugPrint}, and requires {@code tree.predict} and {@code evalPredict}
 * to agree within a delta of 1.0 on every test row.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the code under test
 */
@Test
public void testCpu2() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    // try-with-resources guarantees the download stream is released.
    final AttributeDataset data;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    // First m permuted rows form the training set.
    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Remaining n - m permuted rows form the test set.
    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));

    for (double[] row : testx) {
        assertEquals(tree.predict(row), evalPredict(tree, row), 1.0);
    }
}
 
Example 11
Source File: TreePredictUDFTest.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * Splits the CPU ARFF data 3/4 train and 1/4 test using a permutation from
 * {@code Math.permutate}, trains a {@code RegressionTree}, prints the RMSE via
 * {@code debugPrint}, and requires {@code tree.predict} and {@code evalPredict}
 * to agree within a delta of 1.0 on every test row.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the code under test
 */
@Test
public void testCpu2() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    // Scope the stream to the parse call so it is always closed.
    final AttributeDataset data;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    // First m permuted rows form the training set.
    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Remaining n - m permuted rows form the test set.
    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));

    for (double[] row : testx) {
        Assert.assertEquals(tree.predict(row), evalPredict(tree, row), 1.0);
    }
}
 
Example 12
Source File: TreePredictUDFTest.java    From incubator-hivemall with Apache License 2.0 5 votes vote down vote up
/**
 * 10-fold cross-validation on the CPU ARFF data: for each fold a
 * {@code RegressionTree} is trained and, for every test row,
 * {@code tree.predict} and {@code evalPredict} must agree within a delta of 1.0.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the code under test
 */
@Test
public void testCpu() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    // Close the download stream once parsing finishes (or fails) instead of leaking it.
    final AttributeDataset data;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int k = 10;

    CrossValidation cv = new CrossValidation(n, k);
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(datax, cv.train[i]);
        double[] trainy = Math.slice(datay, cv.train[i]);
        double[][] testx = Math.slice(datax, cv.test[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
        RegressionTree tree = new RegressionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

        for (int j = 0; j < testx.length; j++) {
            Assert.assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0);
        }
    }
}
 
Example 13
Source File: DecisionTreeTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains a {@code DecisionTree} on a shuffled 70% of the dataset and, for each
 * remaining row, calls {@code tree.predict} with a {@code PredictionHandler}
 * that records the visited branches and the leaf. After every prediction the
 * recorded text and the branch map must both be non-empty.
 *
 * @param datasetUrl URL of the ARFF dataset
 * @param responseIndex index of the response attribute passed to the ARFF parser
 * @param numLeafs value passed to the {@code DecisionTree} constructor
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 */
private static void runTracePredict(String datasetUrl, int responseIndex, int numLeafs)
        throws IOException, ParseException {
    URL url = new URL(datasetUrl);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    // try-with-resources guarantees the download stream is released.
    final AttributeDataset ds;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ds = arffParser.parse(is);
    }
    final Attribute[] attrs = ds.attributes();
    final Attribute targetAttr = ds.response();

    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    // Fixed seeds (43) for both the shuffle and the tree's PRNG.
    Random rnd = new Random(43L);
    int numTrain = (int) (x.length * 0.7);
    int[] index = ArrayUtils.shuffle(MathUtils.permutation(x.length), rnd);
    int[] cvTrain = Arrays.copyOf(index, numTrain);
    int[] cvTest = Arrays.copyOfRange(index, numTrain, index.length);

    double[][] trainx = Math.slice(x, cvTrain);
    int[] trainy = Math.slice(y, cvTrain);
    double[][] testx = Math.slice(x, cvTest);

    DecisionTree tree = new DecisionTree(SmileExtUtils.convertAttributeTypes(attrs),
        matrix(trainx, false), trainy, numLeafs, RandomNumberGeneratorFactory.createPRNG(43L));

    final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
    final StringBuilder buf = new StringBuilder();
    for (int i = 0; i < testx.length; i++) {
        final DenseVector test = new DenseVector(testx[i]);
        tree.predict(test, new PredictionHandler() {

            @Override
            public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                    double splitValue) {
                buf.append(attrs[splitFeatureIndex].name);
                buf.append(" [" + splitFeature + "] ");
                buf.append(op);
                buf.append(' ');
                buf.append(splitValue);
                buf.append('\n');

                map.put(attrs[splitFeatureIndex].name + " [" + splitFeature + "] " + op,
                    splitValue);
            }

            @Override
            public void visitLeaf(int output, double[] posteriori) {
                buf.append(targetAttr.toString(output));
            }
        });

        Assert.assertTrue(buf.length() > 0);
        Assert.assertFalse(map.isEmpty());

        // Reset the recorders before the next row.
        StringUtils.clear(buf);
        map.clear();
    }

}
 
Example 14
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds every row of the ARFF dataset to a {@code RandomForestClassifierUDTF}
 * configured with {@code "-trees 1 -seed 71"} as a list of doubles, captures
 * the third forwarded column as {@code Text}, Base91-decodes it and
 * deserializes it into a {@code DecisionTree.Node}.
 *
 * @param urlString URL of the ARFF dataset
 * @return the node deserialized from the forwarded model text
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the UDTF
 */
private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // Scope the stream to the parse call so it is always closed.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for every row and cleared after each process() call.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Single-element array so the anonymous collector can hand the model text back.
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    return DecisionTree.deserialize(b, b.length, true);
}
 
Example 15
Source File: DecisionTreeTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Builds a {@code DecisionTree} on the Titanic training data with effectively
 * unbounded depth and leaf count and invokes {@code predictJsCodegen} through
 * an overridden {@code toString()}. The method makes no assertions; it only
 * checks that construction and code generation complete without throwing.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the delimited text cannot be parsed
 */
@Test
public void testTitanicPruning() throws IOException, ParseException {
    String datasetUrl =
            "https://gist.githubusercontent.com/myui/7cd82c443db84ba7e7add1523d0247a9/raw/f2d3e3051b0292577e8c01a1759edabaa95c5781/titanic_train.tsv";

    URL url = new URL(datasetUrl);

    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setColumnNames(true);
    // NOTE(review): the file is named .tsv but the delimiter is "," - confirm against the data.
    parser.setDelimiter(",");
    parser.setResponseIndex(new NominalAttribute("survived"), 0);

    // Close the download stream once parsing finishes (or fails) instead of leaking it.
    final AttributeDataset train;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        train = parser.parse("titanic train", is);
    }
    double[][] x_ = train.toArray(new double[train.size()][]);
    int[] y = train.toArray(new int[train.size()]);

    // pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked
    // C,C,C,Q,Q,Q,C,Q,C,C
    RoaringBitmap nominalAttrs = new RoaringBitmap();
    nominalAttrs.add(0);
    nominalAttrs.add(1);
    nominalAttrs.add(2);
    nominalAttrs.add(6);
    nominalAttrs.add(8);
    nominalAttrs.add(9);

    int columns = x_[0].length;
    Matrix x = new RowMajorDenseMatrix2d(x_, columns);
    int numVars = (int) Math.ceil(Math.sqrt(columns));
    int maxDepth = Integer.MAX_VALUE;
    int maxLeafs = Integer.MAX_VALUE;
    int minSplits = 2;
    int minLeafSize = 1;
    int[] samples = null;
    PRNG rand = RandomNumberGeneratorFactory.createPRNG(43L);

    final String[] featureNames = new String[] {"pclass", "name", "sex", "age", "sibsp",
            "parch", "ticket", "fare", "cabin", "embarked"};
    final String[] classNames = new String[] {"yes", "no"};
    DecisionTree tree = new DecisionTree(nominalAttrs, x, y, numVars, maxDepth, maxLeafs,
        minSplits, minLeafSize, samples, SplitRule.GINI, rand) {
        @Override
        public String toString() {
            return predictJsCodegen(featureNames, classNames);
        }
    };
    // The result is intentionally discarded; only successful completion is checked.
    tree.toString();
}
 
Example 16
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds every Iris row to a {@code RandomForestClassifierUDTF} configured with
 * {@code "-trees 49"} as a list of {@code "index:value"} strings and requires
 * the collector to be invoked 49 times when the UDTF is closed.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the UDTF
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // try-with-resources guarantees the download stream is released.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for every row and cleared after each process() call.
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
Example 17
Source File: GradientTreeBoostingClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds every Iris row to a {@code GradientTreeBoostingClassifierUDTF}
 * configured with {@code "-trees 490"} as a list of {@code "index:value"}
 * strings and requires the collector to be invoked 490 times when the UDTF is
 * closed.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the UDTF
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // Scope the stream to the parse call so it is always closed.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One list instance is reused for every row and cleared after each process() call.
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
 
Example 18
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds a {@code RandomForestClassifierUDTF} rows whose feature values are all
 * {@code null} (labels are taken from Iris) and expects a {@code HiveException}
 * to be thrown before the final {@code Assert.fail} is reached.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException the expected outcome of this test
 */
@Test(expected = HiveException.class)
public void testIrisDenseAllNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // Close the download stream once parsing finishes (or fails) instead of leaking it.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Only the row widths come from the data; every feature value is replaced by null.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, null);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.fail("should not be called");
}
 
Example 19
Source File: RandomForestClassifierUDTFTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds a {@code RandomForestClassifierUDTF} configured with
 * {@code "-trees 49"} Iris rows in which each feature is replaced by
 * {@code null} whenever a seeded {@code Random} draws a value of at least 0.7,
 * and requires the collector to be invoked 49 times when the UDTF is closed.
 *
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException propagated from the UDTF
 */
@Test
public void testIrisDenseSomeNullFeaturesTest()
        throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // try-with-resources guarantees the download stream is released.
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Fixed seed keeps the pattern of nulled-out features reproducible.
    final Random rand = new Random(43);
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            if (rand.nextDouble() >= 0.7) {
                xi.add(j, null);
            } else {
                xi.add(j, x[i][j]);
            }
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}
 
Example 20
Source File: TreePredictUDFv1Test.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains a {@code RegressionTree} on 3/4 of the permuted CPU ARFF data,
 * generates its op-code script with {@code predictOpCodegen}, and passes the
 * script together with the first test row to
 * {@code TestUtils.testGenericUDFSerialization} for {@code TreePredictUDFv1}.
 *
 * @throws HiveException propagated from the serialization helper
 * @throws IOException if the dataset cannot be downloaded or read
 * @throws ParseException if the ARFF content cannot be parsed
 */
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    // Scope the stream to the parse call so it is always closed.
    final AttributeDataset data;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        data = arffParser.parse(is);
    }
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    // First m permuted rows form the training set.
    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    // Remaining rows form the test features; test labels are not needed here.
    double[][] testx = new double[n - m][];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    String opScript = tree.predictOpCodegen(StackMachine.SEP);

    TestUtils.testGenericUDFSerialization(TreePredictUDFv1.class,
        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)},
        new Object[] {"model_id#1", ModelType.opscode.getId(), opScript,
                ArrayUtils.toList(testx[0])});
}