org.apache.hadoop.hive.ql.udf.generic.Collector Java Examples

The following examples show how to use org.apache.hadoop.hive.ql.udf.generic.Collector. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: PassiveAggressiveRegressionUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testPA1() throws HiveException {
    // Features are a list of "index:value" strings; the regression target is a float.
    final ObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector targetOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;

    final PassiveAggressiveRegressionUDTF regressor = new PassiveAggressiveRegressionUDTF();
    regressor.initialize(new ObjectInspector[] {featuresOI, targetOI});

    // Forwarded rows are irrelevant here; only training and Kryo round-tripping are exercised.
    final Collector discard = new Collector() {
        @Override
        public void collect(Object row) throws HiveException {
            // intentionally empty
        }
    };
    regressor.setCollector(discard);

    regressor.process(new Object[] {Arrays.asList("1:-2", "2:-1"), 1.1f});
    regressor.process(new Object[] {Arrays.asList("3:-2", "1:-1"), -1.3f});

    // The partially trained UDTF must survive a Kryo serialize/deserialize cycle.
    final byte[] kryoBytes = TestUtils.serializeObjectByKryo(regressor);
    TestUtils.deserializeObjectByKryo(kryoBytes, PassiveAggressiveRegressionUDTF.class);

    regressor.close();
}
 
Example #2
Source File: GenerateSeriesUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testTwoIntArgs() throws HiveException {
    // start: plain Java int, end: IntWritable.
    final ObjectInspector[] argOIs = {
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.writableIntObjectInspector};

    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    series.initialize(argOIs);

    final List<IntWritable> collected = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object forwarded) throws HiveException {
            // Copy the value out rather than keeping the forwarded writable instance.
            final IntWritable value = (IntWritable) ((Object[]) forwarded)[0];
            collected.add(new IntWritable(value.get()));
        }
    });

    series.process(new Object[] {1, new IntWritable(3)});

    Assert.assertEquals(
        Arrays.asList(new IntWritable(1), new IntWritable(2), new IntWritable(3)), collected);
}
 
Example #3
Source File: LDAUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testSingleRow() throws HiveException {
    final int numTopics = 2;

    // arg0: the document as a list of words; arg1: constant option string.
    final ObjectInspector docOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector optionsOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics " + numTopics);

    final LDAUDTF lda = new LDAUDTF();
    lda.initialize(new ObjectInspector[] {docOI, optionsOI});

    final String[] words = {"1", "2", "3"};
    lda.process(new Object[] {Arrays.asList(words)});

    // The collector is attached only after process(); the count below therefore
    // reflects rows forwarded while close() runs.
    final MutableInt forwarded = new MutableInt(0);
    lda.setCollector(new Collector() {
        @Override
        public void collect(Object row) throws HiveException {
            forwarded.addValue(1);
        }
    });
    lda.close();

    // Expected: words.length * numTopics forwarded rows.
    Assert.assertEquals(words.length * numTopics, forwarded.getValue());
}
 
Example #4
Source File: PLSAUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testSingleRow() throws HiveException {
    final int numTopics = 2;
    final String options = "-topics " + numTopics;

    final PLSAUDTF plsa = new PLSAUDTF();
    plsa.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, options)});

    // Feed a single three-word document before any collector is attached.
    final String[] doc = new String[] {"1", "2", "3"};
    plsa.process(new Object[] {Arrays.asList(doc)});

    // Count every row forwarded once the collector is installed (i.e. during close()).
    final MutableInt emitted = new MutableInt(0);
    final Collector counter = new Collector() {
        @Override
        public void collect(Object row) throws HiveException {
            emitted.addValue(1);
        }
    };
    plsa.setCollector(counter);
    plsa.close();

    Assert.assertEquals(doc.length * numTopics, emitted.getValue());
}
 
Example #5
Source File: GenerateSeriesUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testSerialization() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    series.initialize(new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.writableIntObjectInspector});

    // Output rows are irrelevant; only the Kryo round-trip after processing matters.
    final Collector ignoreAll = new Collector() {
        @Override
        public void collect(Object forwarded) throws HiveException {
            // intentionally empty
        }
    };
    series.setCollector(ignoreAll);
    series.process(new Object[] {1, new IntWritable(3)});

    final byte[] kryoBytes = TestUtils.serializeObjectByKryo(series);
    TestUtils.deserializeObjectByKryo(kryoBytes, GenerateSeriesUDTF.class);
}
 
Example #6
Source File: GenerateSeriesUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testNegativeStepLong() throws HiveException {
    // start: Java long, end: IntWritable, step: Java int (negative here).
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    final ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;

    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    series.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    final List<LongWritable> collected = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object forwarded) throws HiveException {
            final Object[] cols = (Object[]) forwarded;
            final long v = ((LongWritable) cols[0]).get();
            collected.add(new LongWritable(v));
        }
    });

    series.process(new Object[] {5L, new IntWritable(1), -2});

    final List<LongWritable> expected = Arrays.asList(
        new LongWritable(5), new LongWritable(3), new LongWritable(1));
    Assert.assertEquals(expected, collected);
}
 
Example #7
Source File: GenerateSeriesUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testNegativeStepInt() throws HiveException {
    // start: Java int, end: IntWritable, step: Java long (negative here).
    final ObjectInspector[] argOIs = {
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.writableIntObjectInspector,
            PrimitiveObjectInspectorFactory.javaLongObjectInspector};

    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    series.initialize(argOIs);

    final List<IntWritable> collected = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object forwarded) throws HiveException {
            final IntWritable first = (IntWritable) ((Object[]) forwarded)[0];
            collected.add(new IntWritable(first.get()));
        }
    });

    series.process(new Object[] {5, new IntWritable(1), -2L});

    Assert.assertEquals(
        Arrays.asList(new IntWritable(5), new IntWritable(3), new IntWritable(1)), collected);
}
 
Example #8
Source File: GenerateSeriesUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testThreeLongArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();

    // start: Java long, end: LongWritable, step: Java long.
    final ObjectInspector javaLong = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    final ObjectInspector writableLong =
            PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    series.initialize(new ObjectInspector[] {javaLong, writableLong, javaLong});

    final List<LongWritable> collected = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object forwarded) throws HiveException {
            final LongWritable head = (LongWritable) ((Object[]) forwarded)[0];
            collected.add(new LongWritable(head.get()));
        }
    });

    series.process(new Object[] {1L, new LongWritable(7), 3L});

    final List<LongWritable> expected = Arrays.asList(
        new LongWritable(1), new LongWritable(4), new LongWritable(7));
    Assert.assertEquals(expected, collected);
}
 
Example #9
Source File: GenerateSeriesUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testThreeIntArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();

    // start: Java int, end: IntWritable, step: Java long.
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    final ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    final List<IntWritable> collected = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object forwarded) throws HiveException {
            final Object[] cols = (Object[]) forwarded;
            collected.add(new IntWritable(((IntWritable) cols[0]).get()));
        }
    });

    series.process(new Object[] {1, new IntWritable(7), 3L});

    final List<IntWritable> expected = Arrays.asList(
        new IntWritable(1), new IntWritable(4), new IntWritable(7));
    Assert.assertEquals(expected, collected);
}
 
Example #10
Source File: GenerateSeriesUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testTwoLongArgs() throws HiveException {
    // start: Java int, end: LongWritable; the forwarded values are LongWritable.
    final ObjectInspector[] argOIs = {
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.writableLongObjectInspector};

    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    series.initialize(argOIs);

    final List<LongWritable> collected = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object forwarded) throws HiveException {
            final LongWritable value = (LongWritable) ((Object[]) forwarded)[0];
            collected.add(new LongWritable(value.get()));
        }
    });

    series.process(new Object[] {1, new LongWritable(3)});

    Assert.assertEquals(
        Arrays.asList(new LongWritable(1), new LongWritable(2), new LongWritable(3)), collected);
}
 
Example #11
Source File: RandomForestClassifierUDTFTest.java (from incubator-hivemall, Apache License 2.0; 6 votes)
@Test
public void testSparseRandomForestClassifier() throws HiveException {
    final RandomForestClassifierUDTF forest = new RandomForestClassifierUDTF();
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector});

    // Sparse "index:value" training rows and their class labels, row-aligned.
    final String[][] features = {
            {"1:1.0", "4:1.0", "7:1.0", "12:1.0"}, // 0
            {"2:1.0", "4:1.0", "5:1.0", "11:1.0"}, // 1
            {"1:1.0", "4:1.0", "7:1.0", "113:1.0", "497:1.0", "635:1.0"}, // 2
            {"1:1.0", "4:1.0", "5:1.0", "7:1.0", "10:1.0", "14:1.0"}, // 3
            {"1:1.0", "2:1.0", "4:1.0", "7:1.0", "8:1.0"}, // 4
            {"13:1.0", "18:1.0", "25:1.0", "27:1.0", "65:1.0", "116:1.0", "200:1.0",
                    "468:1.0", "585:1.0", "715:1.0"}}; // 5
    final int[] labels = {1, 1, 0, 1, 1, 0};

    for (int r = 0; r < features.length; r++) {
        forest.process(new Object[] {features[r], labels[r]});
    }

    // Model rows are forwarded on close(); they are discarded here.
    forest.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // intentionally empty
        }
    });

    forest.close();
}
 
Example #12
Source File: TestUtils.java (from incubator-hivemall, Apache License 2.0; 5 votes)
/**
 * Verifies that a {@link GenericUDTF} can be serialized and deserialized with Kryo both right
 * after {@code initialize} and after it has processed {@code rows}.
 *
 * @param clazz the UDTF class; instantiated reflectively through its no-arg constructor
 * @param ois object inspectors passed to {@code initialize}
 * @param rows rows passed to {@code process}, in order
 * @throws HiveException if the UDTF cannot be instantiated (the reflective failure becomes the
 *         cause), or if {@code initialize}, {@code process} or {@code close} throws it
 */
@SuppressWarnings("deprecation")
public static <T extends GenericUDTF> void testGenericUDTFSerialization(@Nonnull Class<T> clazz,
        @Nonnull ObjectInspector[] ois, @Nonnull Object[][] rows) throws HiveException {
    final T udtf;
    try {
        // Class#newInstance() is deprecated because it can rethrow checked constructor
        // exceptions unchecked; going through the Constructor wraps them instead.
        udtf = clazz.getDeclaredConstructor().newInstance();
    } catch (ReflectiveOperationException e) {
        throw new HiveException(e);
    }

    // NOTE(review): GenericUDTF#initialize(ObjectInspector[]) is deprecated in recent Hive
    // versions, which is why @SuppressWarnings("deprecation") is kept on this method.
    udtf.initialize(ois);

    // serialization after initialization
    byte[] serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // noop: forwarded rows are not inspected by this check
        }
    });

    for (Object[] row : rows) {
        udtf.process(row);
    }

    // serialization after processing rows
    serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.close();
}
 
Example #13
Source File: ConditionalEmitUDTFTest.java (from incubator-hivemall, Apache License 2.0; 5 votes)
@Test
public void test() throws HiveException {
    final ConditionalEmitUDTF emitter = new ConditionalEmitUDTF();

    // arg0: list of boolean conditions; arg1: list of values aligned with them.
    final ObjectInspector conditionsOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaBooleanObjectInspector);
    final ObjectInspector valuesOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    emitter.initialize(new ObjectInspector[] {conditionsOI, valuesOI});

    final List<Object> emitted = new ArrayList<>();
    emitter.setCollector(new Collector() {
        @Override
        public void collect(Object row) throws HiveException {
            final Object[] columns = (Object[]) row;
            // every forwarded row must carry exactly one column
            Assert.assertEquals(1, columns.length);
            emitted.add(columns[0]);
        }
    });

    // Only values whose condition is true are expected to be emitted.
    emitter.process(
        new Object[] {Arrays.asList(true, false, true), Arrays.asList("one", "two", "three")});
    Assert.assertEquals(Arrays.asList("one", "three"), emitted);

    emitted.clear();

    emitter.process(
        new Object[] {Arrays.asList(true, true, false), Arrays.asList("one", "two", "three")});
    Assert.assertEquals(Arrays.asList("one", "two"), emitted);

    emitter.close();
}
 
Example #14
Source File: MovingAverageUDTFTest.java (from incubator-hivemall, Apache License 2.0; 5 votes)
/**
 * Feeds the values 1..7 through {@code MovingAverageUDTF} with a constant second argument of 3
 * and checks the running averages that are forwarded, then closes the UDTF.
 */
@Test
public void test() throws HiveException {
    MovingAverageUDTF udtf = new MovingAverageUDTF();

    ObjectInspector argOI0 = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    ObjectInspector argOI1 = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaIntObjectInspector, 3);

    final List<Double> results = new ArrayList<>();
    udtf.initialize(new ObjectInspector[] {argOI0, argOI1});
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] objs = (Object[]) input;
            Assert.assertEquals(1, objs.length);
            Assert.assertTrue(objs[0] instanceof DoubleWritable);
            double x = ((DoubleWritable) objs[0]).get();
            results.add(x);
        }
    });

    // The constant window argument is bound at initialize(); per-row it is passed as null.
    for (int i = 1; i <= 7; i++) {
        udtf.process(new Object[] {(float) i, null});
    }

    Assert.assertEquals(Arrays.asList(1.d, 1.5d, 2.d, 3.d, 4.d, 5.d, 6.d), results);

    // Complete the UDTF lifecycle so close() is exercised as well.
    udtf.close();
}
 
Example #15
Source File: GenerateSeriesUDTFTest.java (from incubator-hivemall, Apache License 2.0; 5 votes)
@Test
public void testTwoConstArgs() throws HiveException {
    // Both bounds are constant IntWritable inspectors (start = 1, end = 3).
    final ObjectInspector startOI =
            PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                TypeInfoFactory.intTypeInfo, new IntWritable(1));
    final ObjectInspector endOI =
            PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                TypeInfoFactory.intTypeInfo, new IntWritable(3));

    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    series.initialize(new ObjectInspector[] {startOI, endOI});

    final List<IntWritable> collected = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object forwarded) throws HiveException {
            final Object[] cols = (Object[]) forwarded;
            collected.add(new IntWritable(((IntWritable) cols[0]).get()));
        }
    });

    series.process(new Object[] {new IntWritable(1), new IntWritable(3)});

    final List<IntWritable> expected = Arrays.asList(
        new IntWritable(1), new IntWritable(2), new IntWritable(3));
    Assert.assertEquals(expected, collected);
}
 
Example #16
Source File: RandomForestClassifierUDTFTest.java (from incubator-hivemall, Apache License 2.0; 5 votes)
@Test
public void testSparseRandomForestClassifierL2Normalized() throws HiveException {
    final RandomForestClassifierUDTF forest = new RandomForestClassifierUDTF();
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector});

    // Same rows as the unnormalized sparse test, with values scaled to unit L2 norm.
    final String[][] features = {
            {"1:0.5", "4:0.5", "7:0.5", "12:0.5"}, // 0
            {"2:0.5", "4:0.5", "5:0.5", "11:0.5"}, // 1
            {"1:0.40824828", "4:0.40824828", "7:0.40824828", "113:0.40824828",
                    "497:0.40824828", "635:0.40824828"}, // 2
            {"1:0.40824828", "4:0.40824828", "5:0.40824828", "7:0.40824828",
                    "10:0.40824828", "14:0.40824828"}, // 3
            {"1:0.4472136", "2:0.4472136", "4:0.4472136", "7:0.4472136", "8:0.4472136"}, // 4
            {"13:0.31622776", "18:0.31622776", "25:0.31622776", "27:0.31622776",
                    "65:0.31622776", "116:0.31622776", "200:0.31622776", "468:0.31622776",
                    "585:0.31622776", "715:0.31622776"}}; // 5
    final int[] labels = {1, 1, 0, 1, 1, 0};

    for (int r = 0; r < features.length; r++) {
        forest.process(new Object[] {features[r], labels[r]});
    }

    // Model rows are forwarded on close(); they are discarded here.
    forest.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // intentionally empty
        }
    });

    forest.close();
}
 
Example #17
Source File: GeneralClassifierUDTFTest.java (from incubator-hivemall, Apache License 2.0; 5 votes)
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    // featureClass is not read in this body; presumably it only binds T at call sites.
    final ListObjectInspector featureListOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    final ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;

    final GeneralClassifierUDTF classifier = new GeneralClassifierUDTF();
    classifier.initialize(new ObjectInspector[] {featureListOI, labelOI});

    final List<Object> modelFeatures = new ArrayList<Object>();
    classifier.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // keep the first column of every forwarded model row
            modelFeatures.add(((Object[]) input)[0]);
        }
    });

    final int label = 0;
    classifier.process(new Object[] {x, label});
    classifier.close();

    // Every forwarded feature must be of the expected model-feature class.
    Assert.assertFalse(modelFeatures.isEmpty());
    for (int i = 0; i < modelFeatures.size(); i++) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            modelFeatures.get(i).getClass());
    }
}
 
Example #18
Source File: TestUtils.java (from incubator-hivemall, Apache License 2.0; 5 votes)
/**
 * Verifies that a {@link GenericUDTF} can be serialized and deserialized with Kryo both right
 * after {@code initialize} and after it has processed {@code rows}.
 *
 * @param clazz the UDTF class; instantiated reflectively through its no-arg constructor
 * @param ois object inspectors passed to {@code initialize}
 * @param rows rows passed to {@code process}, in order
 * @throws HiveException if the UDTF cannot be instantiated (the reflective failure becomes the
 *         cause), or if {@code initialize}, {@code process} or {@code close} throws it
 */
@SuppressWarnings("deprecation")
public static <T extends GenericUDTF> void testGenericUDTFSerialization(@Nonnull Class<T> clazz,
        @Nonnull ObjectInspector[] ois, @Nonnull Object[][] rows) throws HiveException {
    final T udtf;
    try {
        // Class#newInstance() is deprecated because it can rethrow checked constructor
        // exceptions unchecked; going through the Constructor wraps them instead.
        udtf = clazz.getDeclaredConstructor().newInstance();
    } catch (ReflectiveOperationException e) {
        throw new HiveException(e);
    }

    // NOTE(review): GenericUDTF#initialize(ObjectInspector[]) is deprecated in recent Hive
    // versions, which is why @SuppressWarnings("deprecation") is kept on this method.
    udtf.initialize(ois);

    // serialization after initialization
    byte[] serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // noop: forwarded rows are not inspected by this check
        }
    });

    for (Object[] row : rows) {
        udtf.process(row);
    }

    // serialization after processing rows
    serialized = serializeObjectByKryo(udtf);
    deserializeObjectByKryo(serialized, clazz);

    udtf.close();
}
 
Example #19
Source File: GeneralRegressorUDTFTest.java (from incubator-hivemall, Apache License 2.0; 5 votes)
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    // featureClass is not read in this body; presumably it only binds T at call sites.
    final ListObjectInspector featureListOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    final ObjectInspector targetOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;

    final GeneralRegressorUDTF regressor = new GeneralRegressorUDTF();
    regressor.initialize(new ObjectInspector[] {featureListOI, targetOI});

    final List<Object> modelFeatures = new ArrayList<Object>();
    regressor.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // keep the first column of every forwarded model row
            modelFeatures.add(((Object[]) input)[0]);
        }
    });

    final float target = 1.f;
    regressor.process(new Object[] {x, target});
    regressor.close();

    // Every forwarded feature must be of the expected model-feature class.
    Assert.assertFalse(modelFeatures.isEmpty());
    for (int i = 0; i < modelFeatures.size(); i++) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            modelFeatures.get(i).getClass());
    }
}
 
Example #20
Source File: HiveGenericUDTF.java (from flink, Apache License 2.0; 4 votes)
/**
 * Test hook: installs {@code collector} on the wrapped {@code function} by delegating to its
 * {@code setCollector}.
 *
 * @param collector the collector to hand to the wrapped function
 */
@VisibleForTesting
protected final void setCollector(Collector collector) {
	this.function.setCollector(collector);
}
 
Example #21
Source File: HiveGenericUDTF.java (from flink, Apache License 2.0; 4 votes)
/**
 * Test hook: installs {@code collector} on the wrapped {@code function} by delegating to its
 * {@code setCollector}.
 *
 * @param collector the collector to hand to the wrapped function
 */
@VisibleForTesting
protected final void setCollector(Collector collector) {
	this.function.setCollector(collector);
}
 
Example #22
Source File: RandomForestRegressionUDTFTest.java (from incubator-hivemall, Apache License 2.0; 4 votes)
private static RegressionTree.Node getRegressionTreeFromSparseInput()
        throws IOException, ParseException, HiveException {
    final String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};

    final double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
    final double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

    // Sparse "hash:value" features, a double target, and options for a single tree with
    // a fixed seed (-trees 1 -seed 71).
    final RandomForestRegressionUDTF forest = new RandomForestRegressionUDTF();
    final ObjectInspector options = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, options});

    // A single buffer is refilled for each row and cleared after process().
    final List<String> features = new ArrayList<String>(x[0].length);
    int rowIdx = 0;
    for (double[] row : x) {
        for (int col = 0; col < row.length; col++) {
            features.add(mhash(featureNames[col]) + ":" + row[col]);
        }
        forest.process(new Object[] {features, y[rowIdx++]});
        features.clear();
    }

    // Column 2 of the forwarded row is read as the Base91-encoded model text.
    final Text[] modelHolder = new Text[1];
    forest.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            modelHolder[0] = (Text) ((Object[]) input)[2];
        }
    });
    forest.close();

    final Text model = modelHolder[0];
    Assert.assertNotNull(model);

    final byte[] decoded = Base91.decode(model.getBytes(), 0, model.getLength());
    return RegressionTree.deserialize(decoded, decoded.length, true);
}
 
Example #23
Source File: RandomForestRegressionUDTFTest.java (from incubator-hivemall, Apache License 2.0; 4 votes)
private static RegressionTree.Node getRegressionTreeFromDenseInput()
        throws IOException, ParseException, HiveException {
    final double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
    final double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

    // Dense list<double> features, a double target, and options for a single tree with a
    // fixed seed (-trees 1 -seed 71).
    final RandomForestRegressionUDTF forest = new RandomForestRegressionUDTF();
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71")});

    // A single buffer is refilled for each row and cleared after process().
    final List<Double> features = new ArrayList<Double>(x[0].length);
    for (int r = 0; r < x.length; r++) {
        for (double v : x[r]) {
            features.add(v);
        }
        forest.process(new Object[] {features, y[r]});
        features.clear();
    }

    // Column 2 of the forwarded row is read as the Base91-encoded model text.
    final Text[] modelHolder = new Text[1];
    forest.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            modelHolder[0] = (Text) ((Object[]) input)[2];
        }
    });
    forest.close();

    final Text model = modelHolder[0];
    Assert.assertNotNull(model);

    final byte[] decoded = Base91.decode(model.getBytes(), 0, model.getLength());
    return RegressionTree.deserialize(decoded, decoded.length, true);
}
 
Example #24
Source File: RandomForestRegressionUDTFTest.java (from incubator-hivemall, Apache License 2.0; 4 votes)
@Test
public void testSparse() throws IOException, ParseException, HiveException {
    final String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};

    final double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
    final double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

    final RandomForestRegressionUDTF forest = new RandomForestRegressionUDTF();
    final ObjectInspector options = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, options});

    // A single buffer of sparse "hash:value" strings is refilled for each row.
    final List<String> features = new ArrayList<String>(x[0].length);
    int rowIdx = 0;
    for (double[] row : x) {
        for (int col = 0; col < row.length; col++) {
            features.add(mhash(featureNames[col]) + ":" + row[col]);
        }
        forest.process(new Object[] {features, y[rowIdx++]});
        features.clear();
    }

    // Count the rows forwarded on close(); 49 are expected, matching -trees 49.
    final MutableInt forwarded = new MutableInt(0);
    forest.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            forwarded.addValue(1);
        }
    });
    forest.close();

    Assert.assertEquals(49, forwarded.getValue());
}
 
Example #25
Source File: RandomForestRegressionUDTFTest.java (from incubator-hivemall, Apache License 2.0; 4 votes)
@Test
public void testDense() throws IOException, ParseException, HiveException {
    final double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
    final double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

    final RandomForestRegressionUDTF forest = new RandomForestRegressionUDTF();
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49")});

    // A single buffer of dense feature values is refilled for each row.
    final List<Double> features = new ArrayList<Double>(x[0].length);
    for (int r = 0; r < x.length; r++) {
        for (double v : x[r]) {
            features.add(v);
        }
        forest.process(new Object[] {features, y[r]});
        features.clear();
    }

    // Count the rows forwarded on close(); 49 are expected, matching -trees 49.
    final MutableInt forwarded = new MutableInt(0);
    forest.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            forwarded.addValue(1);
        }
    });
    forest.close();

    Assert.assertEquals(49, forwarded.getValue());
}
 
Example #26
Source File: GradientTreeBoostingClassifierUDTFTest.java (from incubator-hivemall, Apache License 2.0; 4 votes)
/**
 * Trains {@code GradientTreeBoostingClassifierUDTF} with {@code -trees 490} on the Iris data set
 * (fetched over HTTP and fed as sparse "index:value" features), then checks that close()
 * forwards 490 rows.
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    // try-with-resources: the HTTP stream was previously never closed (resource leak).
    final AttributeDataset iris;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // A single buffer is refilled for each row and cleared after process().
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
 
Example #27
Source File: GradientTreeBoostingClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    // Requested number of boosting trees; the assertion expects one forwarded row per tree.
    final int numTrees = 490;

    // Download and parse the Iris ARFF file; attribute index 4 is used as the response (label).
    final URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    final ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // try-with-resources closes the network stream even if parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    final int size = iris.size();
    final double[][] x = iris.toArray(new double[size][]);
    final int[] y = iris.toArray(new int[size]);

    final GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    final ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees " + numTrees);
    // Arguments: dense features as a list of doubles, an int label, and the option string.
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // One buffer is refilled with each row's values and cleared after processing.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (double value : x[i]) {
            xi.add(value);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Count the rows forwarded to the collector during close().
    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(numTrees, count.getValue());
}
 
Example #28
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
@Test
public void testNews20BinarySparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-seed 71 -trees " + numTrees);
    // Arguments: sparse features as a list of strings, an int label, and the option string.
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Each line is space-separated: the first token is the integer label, the rest are features.
    final List<String> features = new ArrayList<String>();
    // try-with-resources closes the reader even if parsing or process() throws mid-file.
    try (BufferedReader news20 = readFile("news20-small.binary.gz")) {
        String line;
        while ((line = news20.readLine()) != null) {
            StringTokenizer tokens = new StringTokenizer(line, " ");
            int label = Integer.parseInt(tokens.nextToken());
            if (label == -1) {
                label = 0; // map the -1 label to class 0
            }
            while (tokens.hasMoreTokens()) {
                features.add(tokens.nextToken());
            }
            // Rows without features are skipped; the buffer is only non-empty after a fill.
            if (!features.isEmpty()) {
                udtf.process(new Object[] {features, label});
                features.clear();
            }
        }
    }

    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        // synchronized: presumably collect() may be called from several training threads — confirm.
        @Override
        public synchronized void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            // Columns 4 and 5 of each forwarded row are read as OOB error and OOB test counts.
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(numTrees, count.getValue());
    // Guard against 0/0, which would yield NaN and fail with a misleading "too high" message.
    Assert.assertTrue("no out-of-bag tests were recorded", oobTests.getValue() > 0);
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.3);
}
 
Example #29
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
@Test
public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-stratified_sampling -seed 71 -trees " + numTrees);
    // Arguments: sparse features as a list of strings, an int label, and the option string.
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Each line is space-separated: the first token is the integer label, the rest are features.
    final List<String> features = new ArrayList<String>();
    // try-with-resources closes the reader even if an assertion or process() throws mid-file.
    try (BufferedReader news20 = readFile("news20-multiclass.gz")) {
        String line;
        while ((line = news20.readLine()) != null) {
            StringTokenizer tokens = new StringTokenizer(line, " ");
            int label = Integer.parseInt(tokens.nextToken());
            while (tokens.hasMoreTokens()) {
                features.add(tokens.nextToken());
            }
            // Every row in this fixture is expected to carry at least one feature.
            Assert.assertFalse(features.isEmpty());
            udtf.process(new Object[] {features, label});

            features.clear();
        }
    }

    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        // synchronized: presumably collect() may be called from several training threads — confirm.
        @Override
        public synchronized void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            // Columns 4 and 5 of each forwarded row are read as OOB error and OOB test counts.
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(numTrees, count.getValue());
    // Guard against 0/0, which would yield NaN and fail with a misleading "too high" message.
    Assert.assertTrue("no out-of-bag tests were recorded", oobTests.getValue() > 0);
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    // TODO why multi-class classification so bad??
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8);
}
 
Example #30
Source File: RandomForestClassifierUDTFTest.java — from incubator-hivemall (Apache License 2.0), 4 votes
/**
 * Trains a single-tree random forest on an ARFF dataset fed as sparse features and returns
 * the deserialized tree.
 *
 * @param urlString location of an ARFF file; attribute index 4 is used as the response (label)
 * @return the root node decoded from the model text of the forwarded row
 * @throws IOException if the URL cannot be read
 * @throws ParseException if the ARFF content cannot be parsed
 * @throws HiveException if the UDTF fails to initialize, process or close
 */
private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString)
        throws IOException, ParseException, HiveException {
    final URL url = new URL(urlString);
    final ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    final AttributeDataset iris;
    // try-with-resources closes the network stream even if parse() throws.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        iris = arffParser.parse(is);
    }
    final int size = iris.size();
    final double[][] x = iris.toArray(new double[size][]);
    final int[] y = iris.toArray(new int[size]);

    // A fixed seed and a single tree keep the resulting model deterministic across calls.
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Encode each feature as "index:value"; one buffer is reused and cleared per row.
    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        final double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // Capture column 2 (the model text) of the last row forwarded during close().
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    // Text.getBytes() may return a backing array longer than the content, so bound it by getLength().
    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    // NOTE(review): meaning of the trailing 'true' flag is defined by DecisionTree.deserialize — confirm.
    return DecisionTree.deserialize(b, b.length, true);
}