Java Code Examples for org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory#getStandardListObjectInspector()

The following examples show how to use org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory#getStandardListObjectInspector(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
/**
 * Feeds a writable list built from {@code {0.1, 1.1, 2.1}} to {@code ToJsonUDF} and checks
 * that the serialized text is {@code [0.1,1.1,2.1]}.
 */
@Test
public void testDoubleArray() throws Exception {
    // argument type: list of writable doubles
    final ObjectInspector doubleListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    final double[] values = {0.1, 1.1, 2.1};
    final DeferredObject input =
            new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(values));

    final ToJsonUDF udf = new ToJsonUDF();
    udf.initialize(new ObjectInspector[] {doubleListOI});
    final Text json = udf.evaluate(new DeferredObject[] {input});

    Assert.assertEquals("[0.1,1.1,2.1]", json.toString());

    udf.close();
}
 
Example 2
/**
 * Validates the arguments of {@code split_to_multimap(string, string, string)} and declares
 * the return type of the function.
 *
 * @param arguments inspectors of the call arguments; exactly {@code ARG_COUNT} are required
 *        and each one must compare equal in type to the Java string inspector
 * @return a standard map inspector whose keys are strings and whose values are lists of strings
 * @throws UDFArgumentException a {@code UDFArgumentLengthException} when the argument count
 *         differs from {@code ARG_COUNT}, or a {@code UDFArgumentTypeException} naming the
 *         first argument that is not a string
 */
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    // Check that the expected number of arguments was passed
    if (arguments.length != ARG_COUNT) {
        throw new UDFArgumentLengthException(
                "The function split_to_multimap(string, string, string) takes exactly " + ARG_COUNT + " arguments.");
    }

    final ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;

    // Check that every argument is a string. The bound follows the argument count validated
    // above instead of repeating a literal, so the loop cannot drift from ARG_COUNT.
    for (int i = 0; i < arguments.length; i++) {
        if (!ObjectInspectorUtils.compareTypes(stringOI, arguments[i])) {
            throw new UDFArgumentTypeException(i,
                    "\"" + stringOI.getTypeName() + "\" "
                            + "expected at function split_to_multimap, but "
                            + "\"" + arguments[i].getTypeName() + "\" "
                            + "is found");
        }
    }

    // Result type: map<string, array<string>>
    ObjectInspector mapKeyOI = stringOI;
    ObjectInspector mapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);

    return ObjectInspectorFactory.getStandardMapObjectInspector(mapKeyOI, mapValueOI);
}
 
Example 3
/**
 * Initializes {@code PassiveAggressiveUDTF.PA1} with the constant option {@code "-c 0.1"},
 * trains once on features 1, 2, 3 with label 1, and checks that each of the three weights
 * is 0.1 (tolerance 1e-5).
 */
@Test
public void testPA1TrainWithParameter() throws UDFArgumentException {
    final ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ListObjectInspector featureListOI =
            ObjectInspectorFactory.getStandardListObjectInspector(labelOI);
    final ObjectInspector optionOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-c 0.1");

    // the option string is supplied as a constant third argument
    final PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();
    udtf.initialize(new ObjectInspector[] {featureListOI, labelOI, optionOI});

    // single training step
    final List<?> features = (List<?>) featureListOI.getList(new Object[] {1, 2, 3});
    udtf.train(features, 1);

    // every trained feature is expected to carry the same weight
    for (int feature = 1; feature <= 3; feature++) {
        assertEquals(0.1000000f, udtf.model.get(feature).get(), 1e-5f);
    }
}
 
Example 4
/**
 * Passes parallel key and value string lists to {@code UDFMapBuild} and checks that the
 * result equals the map key1→value1, key2→value2, key3→value3.
 */
@Test
public void testMapBuild() throws Exception {
    final ObjectInspector keysOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector valuesOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);

    final UDFMapBuild udf = new UDFMapBuild();
    udf.initialize(new ObjectInspector[] {keysOI, valuesOI});

    final List<String> keys = ImmutableList.of("key1", "key2", "key3");
    final List<String> values = ImmutableList.of("value1", "value2", "value3");
    final DeferredObject[] args = {new DeferredJavaObject(keys), new DeferredJavaObject(values)};
    final LinkedHashMap<String, String> output =
            (LinkedHashMap<String, String>) udf.evaluate(args);

    // expected entries, inserted in key order
    final LinkedHashMap<String, String> expect = Maps.newLinkedHashMap();
    expect.put("key1", "value1");
    expect.put("key2", "value2");
    expect.put("key3", "value3");

    Assert.assertEquals("map_build() test", true, MapUtils.mapEquals(output, expect));
}
 
Example 5
/**
 * Builds the struct inspector for the merged output: one field per input inspector, named
 * {@code f0}, {@code f1}, ... in input order, each typed as a list of the writable standard
 * form of that input.
 *
 * @param inputOIs input inspectors; rejected by {@code Preconditions.checkNotNull} when null
 * @return a standard struct inspector over the generated field names and list inspectors
 */
@Nonnull
private static StructObjectInspector internalMergeOutputOI(
        @CheckForNull PrimitiveObjectInspector[] inputOIs) throws UDFArgumentException {
    Preconditions.checkNotNull(inputOIs);

    final List<String> names = new ArrayList<String>(inputOIs.length);
    final List<ObjectInspector> columnOIs = new ArrayList<ObjectInspector>(inputOIs.length);
    int index = 0;
    for (PrimitiveObjectInspector inputOI : inputOIs) {
        names.add("f" + index);
        index++;

        final ObjectInspector writableOI = ObjectInspectorUtils.getStandardObjectInspector(
            inputOI, ObjectInspectorCopyOption.WRITABLE);
        columnOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(writableOI));
    }
    return ObjectInspectorFactory.getStandardStructObjectInspector(names, columnOIs);
}
 
Example 6
Source Project: flink   File: HiveInspectors.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Recursively maps a Hive {@code TypeInfo} to a standard Java object inspector.
 *
 * <p>Primitive types map to the primitive Java inspector; list, map and struct types are
 * built from the inspectors of their element, key/value and field types respectively.
 *
 * @param type the Hive type description to convert
 * @return the matching object inspector
 * @throws CatalogException if the type category is not PRIMITIVE, LIST, MAP or STRUCT
 */
private static ObjectInspector getObjectInspector(TypeInfo type) {
	switch (type.getCategory()) {
		case PRIMITIVE:
			return PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
					(PrimitiveTypeInfo) type);
		case LIST: {
			ObjectInspector elementInspector =
					getObjectInspector(((ListTypeInfo) type).getListElementTypeInfo());
			return ObjectInspectorFactory.getStandardListObjectInspector(elementInspector);
		}
		case MAP: {
			MapTypeInfo mapInfo = (MapTypeInfo) type;
			ObjectInspector keyInspector = getObjectInspector(mapInfo.getMapKeyTypeInfo());
			ObjectInspector valueInspector = getObjectInspector(mapInfo.getMapValueTypeInfo());
			return ObjectInspectorFactory.getStandardMapObjectInspector(keyInspector, valueInspector);
		}
		case STRUCT: {
			StructTypeInfo structInfo = (StructTypeInfo) type;
			// one inspector per struct field, in declaration order
			List<ObjectInspector> memberInspectors = new ArrayList<ObjectInspector>();
			for (TypeInfo memberType : structInfo.getAllStructFieldTypeInfos()) {
				memberInspectors.add(getObjectInspector(memberType));
			}
			return ObjectInspectorFactory.getStandardStructObjectInspector(
					structInfo.getAllStructFieldNames(), memberInspectors);
		}
		default:
			throw new CatalogException("Unsupported Hive type category " + type.getCategory());
	}
}
 
Example 7
/**
 * Trains {@code GeneralClassifierUDTF} with the AdamHD optimizer options on
 * {@code adam_test_10000.tsv.gz} and asserts that the cumulative loss stays below 800.
 *
 * <p>Each input line is split on a space into a comma-separated feature list followed by an
 * integer label.
 *
 * @throws IOException if the data file cannot be read or closed
 * @throws HiveException if the UDTF rejects its arguments or fails while training
 */
@Test
public void testAdamHD() throws IOException, HiveException {
    String filePath = "adam_test_10000.tsv.gz";
    String options =
            "-loss logloss -opt AdamHD -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

    ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

    udtf.initialize(new ObjectInspector[] {stringListOI,
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

    // try-with-resources closes the reader even when parsing or training throws
    try (BufferedReader reader = readFile(filePath)) {
        for (String line = reader.readLine(); line != null; line = reader.readLine()) {
            StringTokenizer tokenizer = new StringTokenizer(line, " ");

            String featureLine = tokenizer.nextToken();
            List<String> X = Arrays.asList(featureLine.split(","));

            String labelLine = tokenizer.nextToken();
            Integer y = Integer.valueOf(labelLine);

            udtf.process(new Object[] {X, y});
        }
    }

    udtf.finalizeTraining();

    Assert.assertTrue(
        "CumulativeLoss is expected to be less than 800: " + udtf.getCumulativeLoss(),
        udtf.getCumulativeLoss() < 800);
}
 
Example 8
/**
 * Trains {@code GeneralClassifierUDTF} with the RMSpropGraves optimizer options on
 * {@code adam_test_10000.tsv.gz} and asserts that the cumulative loss stays below 1200.
 *
 * <p>Each input line is split on a space into a comma-separated feature list followed by an
 * integer label.
 *
 * @throws IOException if the data file cannot be read or closed
 * @throws HiveException if the UDTF rejects its arguments or fails while training
 */
@Test
public void testRMSpropGraves() throws IOException, HiveException {
    String filePath = "adam_test_10000.tsv.gz";
    String options =
            "-loss logloss -opt RMSpropGraves -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

    ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

    udtf.initialize(new ObjectInspector[] {stringListOI,
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

    // try-with-resources closes the reader even when parsing or training throws
    try (BufferedReader reader = readFile(filePath)) {
        for (String line = reader.readLine(); line != null; line = reader.readLine()) {
            StringTokenizer tokenizer = new StringTokenizer(line, " ");

            String featureLine = tokenizer.nextToken();
            List<String> X = Arrays.asList(featureLine.split(","));

            String labelLine = tokenizer.nextToken();
            Integer y = Integer.valueOf(labelLine);

            udtf.process(new Object[] {X, y});
        }
    }

    udtf.finalizeTraining();

    Assert.assertTrue(
        "CumulativeLoss is expected to be less than 1200: " + udtf.getCumulativeLoss(),
        udtf.getCumulativeLoss() < 1200);
}
 
Example 9
/**
 * Trains {@code GeneralClassifierUDTF} with the Adam optimizer and the {@code -amsgrad}
 * option on {@code adam_test_10000.tsv.gz} and asserts that the cumulative loss stays
 * below 1200.
 *
 * <p>Each input line is split on a space into a comma-separated feature list followed by an
 * integer label.
 *
 * @throws IOException if the data file cannot be read or closed
 * @throws HiveException if the UDTF rejects its arguments or fails while training
 */
@Test
public void testAdamAmsgrad() throws IOException, HiveException {
    String filePath = "adam_test_10000.tsv.gz";
    String options =
            "-loss logloss -opt Adam -amsgrad -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

    ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

    udtf.initialize(new ObjectInspector[] {stringListOI,
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

    // try-with-resources closes the reader even when parsing or training throws
    try (BufferedReader reader = readFile(filePath)) {
        for (String line = reader.readLine(); line != null; line = reader.readLine()) {
            StringTokenizer tokenizer = new StringTokenizer(line, " ");

            String featureLine = tokenizer.nextToken();
            List<String> X = Arrays.asList(featureLine.split(","));

            String labelLine = tokenizer.nextToken();
            Integer y = Integer.valueOf(labelLine);

            udtf.process(new Object[] {X, y});
        }
    }

    udtf.finalizeTraining();

    Assert.assertTrue(
        "CumulativeLoss is expected to be less than 1200: " + udtf.getCumulativeLoss(),
        udtf.getCumulativeLoss() < 1200);
}
 
Example 10
/**
 * Validates the arguments of {@code array_slice(array, start, length)} and declares the
 * return type of the function.
 *
 * <p>Stores the list inspector of the array argument in {@code arrayOI} and its element
 * inspector in {@code arrayElementOI}.
 *
 * @param arguments inspectors of the call arguments; exactly {@code ARG_COUNT} are required,
 *        the one at {@code ARRAY_IDX} must be a list and the ones from index 1 onward must
 *        compare equal in type to the writable int inspector
 * @return a standard list inspector over the element type of the input array
 * @throws UDFArgumentException a {@code UDFArgumentLengthException} when the argument count
 *         differs from {@code ARG_COUNT}, or a {@code UDFArgumentTypeException} naming the
 *         first argument of the wrong type
 */
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    // Check that the expected number of arguments was passed
    if (arguments.length != ARG_COUNT) {
        throw new UDFArgumentLengthException(
                "The function array_slice(array, start, length) takes exactly " + ARG_COUNT + " arguments.");
    }

    // Check if ARRAY_IDX argument is of category LIST
    if (!arguments[ARRAY_IDX].getCategory().equals(ObjectInspector.Category.LIST)) {
        throw new UDFArgumentTypeException(ARRAY_IDX,
                "\"" + org.apache.hadoop.hive.serde.serdeConstants.LIST_TYPE_NAME + "\" "
                        + "expected at function array_slice, but "
                        + "\"" + arguments[ARRAY_IDX].getTypeName() + "\" "
                        + "is found");
    }

    arrayOI = (ListObjectInspector) arguments[ARRAY_IDX];
    arrayElementOI = arrayOI.getListElementObjectInspector();

    ObjectInspector expectOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;

    // Check that the remaining arguments (start, length) are ints. The bound follows the
    // argument count validated above instead of repeating a literal.
    for (int i = 1; i < arguments.length; i++) {
        if (!ObjectInspectorUtils.compareTypes(expectOI, arguments[i])) {
            throw new UDFArgumentTypeException(i,
                    "\"" + expectOI.getTypeName() + "\""
                            + " expected at function array_slice, but "
                            + "\"" + arguments[i].getTypeName() + "\""
                            + " is found");
        }
    }

    return ObjectInspectorFactory.getStandardListObjectInspector(arrayElementOI);
}
 
Example 11
Source Project: hive-funnel-udf   File: FunnelTest.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Requests an evaluator from {@code Funnel} with a list-of-long inspector in the first
 * position followed by two string inspectors; the test passes only when
 * {@code UDFArgumentTypeException} is thrown.
 */
@Test(expected = UDFArgumentTypeException.class)
public void testComplexParamPosition1() throws HiveException {
    final Funnel udaf = new Funnel();

    final ObjectInspector longListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaLongObjectInspector);
    final ObjectInspector[] inputOIs = {
        longListOI,
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        PrimitiveObjectInspectorFactory.javaStringObjectInspector
    };

    // the returned evaluator is irrelevant: the call itself is expected to throw
    udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
}
 
Example 12
/**
 * Trains {@code GeneralClassifierUDTF} with the Adam optimizer and the
 * {@code -eta inv -eta0 0.1} options on {@code adam_test_10000.tsv.gz} and asserts that the
 * cumulative loss stays below 900.
 *
 * <p>Each input line is split on a space into a comma-separated feature list followed by an
 * integer label.
 *
 * @throws IOException if the data file cannot be read or closed
 * @throws HiveException if the UDTF rejects its arguments or fails while training
 */
@Test
public void testAdamInvScaleEta() throws IOException, HiveException {
    String filePath = "adam_test_10000.tsv.gz";
    String options =
            "-eta inv -eta0 0.1 -loss logloss -opt Adam -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

    ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

    udtf.initialize(new ObjectInspector[] {stringListOI,
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

    // try-with-resources closes the reader even when parsing or training throws
    try (BufferedReader reader = readFile(filePath)) {
        for (String line = reader.readLine(); line != null; line = reader.readLine()) {
            StringTokenizer tokenizer = new StringTokenizer(line, " ");

            String featureLine = tokenizer.nextToken();
            List<String> X = Arrays.asList(featureLine.split(","));

            String labelLine = tokenizer.nextToken();
            Integer y = Integer.valueOf(labelLine);

            udtf.process(new Object[] {X, y});
        }
    }

    udtf.finalizeTraining();

    Assert.assertTrue(
        "CumulativeLoss is expected to be less than 900: " + udtf.getCumulativeLoss(),
        udtf.getCumulativeLoss() < 900);
}
 
Example 13
/**
 * Validates that at least two list arguments with matching element types were supplied,
 * stores their list inspectors in {@code _listOIs}, and declares the return type.
 *
 * @param argOIs inspectors of the call arguments
 * @return a standard list inspector over the writable standard form of the first list's
 *         element type
 * @throws UDFArgumentException if fewer than two arguments are given or an element type
 *         differs from that of the first list
 */
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    final int numArgs = argOIs.length;
    if (numArgs < 2) {
        throw new UDFArgumentException("Expecting at least two arrays as arguments");
    }

    final ListObjectInspector[] lists = new ListObjectInspector[numArgs];
    final ListObjectInspector firstOI = HiveUtils.asListOI(argOIs[0]);
    final ObjectInspector firstElemOI = firstOI.getListElementObjectInspector();
    lists[0] = firstOI;

    // every further list must share the element type of the first one
    for (int pos = 1; pos < numArgs; pos++) {
        final ListObjectInspector otherOI = HiveUtils.asListOI(argOIs[pos]);
        final boolean sameElementType = ObjectInspectorUtils.compareTypes(firstElemOI,
            otherOI.getListElementObjectInspector());
        if (!sameElementType) {
            throw new UDFArgumentException("Array types does not match: " + firstOI.getTypeName()
                    + " != " + otherOI.getTypeName());
        }
        lists[pos] = otherOI;
    }

    this._listOIs = lists;

    final ObjectInspector returnElemOI = ObjectInspectorUtils.getStandardObjectInspector(
        firstElemOI, ObjectInspectorCopyOption.WRITABLE);
    return ObjectInspectorFactory.getStandardListObjectInspector(returnElemOI);
}
 
Example 14
/**
 * Trains {@code GeneralClassifierUDTF} with the adadelta optimizer and L1 regularization
 * options on {@code adam_test_10000.tsv.gz} and asserts that the cumulative loss stays
 * below 1500.
 *
 * <p>Each input line is split on a space into a comma-separated feature list followed by an
 * integer label.
 *
 * @throws IOException if the data file cannot be read or closed
 * @throws HiveException if the UDTF rejects its arguments or fails while training
 */
@Test
public void testAdaDeltaL1() throws IOException, HiveException {
    String filePath = "adam_test_10000.tsv.gz";
    String options =
            "-loss logloss -opt adadelta -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

    ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

    udtf.initialize(new ObjectInspector[] {stringListOI,
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

    // try-with-resources closes the reader even when parsing or training throws
    try (BufferedReader reader = readFile(filePath)) {
        for (String line = reader.readLine(); line != null; line = reader.readLine()) {
            StringTokenizer tokenizer = new StringTokenizer(line, " ");

            String featureLine = tokenizer.nextToken();
            List<String> X = Arrays.asList(featureLine.split(","));

            String labelLine = tokenizer.nextToken();
            Integer y = Integer.valueOf(labelLine);

            udtf.process(new Object[] {X, y});
        }
    }

    udtf.finalizeTraining();

    Assert.assertTrue(
        "CumulativeLoss is expected to be less than 1500: " + udtf.getCumulativeLoss(),
        udtf.getCumulativeLoss() < 1500);
}
 
Example 15
Source Project: incubator-hivemall   File: ArangeUDF.java    License: Apache License 2.0 4 votes vote down vote up
/**
 * Validates the one to three integer arguments of {@code arange} and records their
 * inspectors in {@code startOI}, {@code stopOI} and {@code stepOI}.
 *
 * <p>With one argument only {@code stopOI} is set; with two, {@code startOI} and
 * {@code stopOI}; with three, {@code stepOI} is validated and set first, followed by
 * {@code startOI} and {@code stopOI}.
 *
 * @param argOIs inspectors of the call arguments
 * @return a standard list inspector over the standard form of the writable int inspector
 * @throws UDFArgumentException if the argument count is outside 1..3 or an argument is not
 *         an integer
 */
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    final int numArgs = argOIs.length;
    if (numArgs < 1 || numArgs > 3) {
        throw new UDFArgumentException(
            "arange([int start=0, ] int stop, [int step=1]) takes 1~3 arguments: "
                    + argOIs.length);
    }

    if (numArgs == 1) {
        // arange(stop)
        if (!HiveUtils.isIntegerOI(argOIs[0])) {
            throw new UDFArgumentException(
                "arange(int stop) expects integer for the 1st argument: "
                        + argOIs[0].getTypeName());
        }
        this.stopOI = HiveUtils.asIntegerOI(argOIs[0]);
    } else {
        if (numArgs == 3) {
            // arange(start, stop, step): the step is checked before start and stop
            if (!HiveUtils.isIntegerOI(argOIs[2])) {
                throw new UDFArgumentException(
                    "arange(int start, int stop, int step) expects integer for the 3rd argument: "
                            + argOIs[2].getTypeName());
            }
            this.stepOI = HiveUtils.asIntegerOI(argOIs[2]);
        }

        // start and stop are shared by the two- and three-argument forms
        if (!HiveUtils.isIntegerOI(argOIs[0])) {
            throw new UDFArgumentException(
                "arange(int start, int stop) expects integer for the 1st argument: "
                        + argOIs[0].getTypeName());
        }
        this.startOI = HiveUtils.asIntegerOI(argOIs[0]);
        if (!HiveUtils.isIntegerOI(argOIs[1])) {
            throw new UDFArgumentException(
                "arange(int start, int stop) expects integer for the 2nd argument: "
                        + argOIs[1].getTypeName());
        }
        this.stopOI = HiveUtils.asIntegerOI(argOIs[1]);
    }

    final ObjectInspector elementOI = ObjectInspectorUtils.getStandardObjectInspector(
        PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    return ObjectInspectorFactory.getStandardListObjectInspector(elementOI);
}
 
Example 16
/**
 * Trains {@code FactorizationMachineUDTF} in classification mode on synthetic, seeded
 * random data and asserts that the accuracy measured in the last of 10000 rounds exceeds 0.95.
 *
 * <p>Each round generates {@code ROW} rows: the first half is labelled -1 and always has
 * feature 1 set, plus random features in columns 2..19; the second half is labelled +1 and
 * only has random features in columns 21..40. Column 20 is never populated.
 *
 * @throws HiveException propagated from the UDTF calls
 */
@Test
public void testClassification() throws HiveException {
    final int ROW = 10, COL = 40;

    FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
    ListObjectInspector xOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    DoubleObjectInspector yOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    ObjectInspector paramOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-adareg -int_feature -factors 20 -classification -seed 31 -iters 10");
    udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
    // NOTE(review): presumably the model returned here is the one the UDTF trains in
    // process(); the predictions below rely on that — confirm against initModel.
    FactorizationMachineModel model = udtf.initModel(udtf._params);
    Assert.assertTrue("Actual class: " + model.getClass().getName(),
        model instanceof FMIntFeatureMapModel);

    float accuracy = 0.f;
    // fixed seed: the generated data depends on the exact order of the rnd calls below
    final Random rnd = new Random(201L);
    for (int numberOfIteration = 0; numberOfIteration < 10000; numberOfIteration++) {
        ArrayList<IntFeature[]> fArrayList = new ArrayList<IntFeature[]>();
        ArrayList<Double> ans = new ArrayList<Double>();
        for (int i = 0; i < ROW; i++) {
            ArrayList<IntFeature> feature = new ArrayList<IntFeature>();
            for (int j = 1; j <= COL; j++) {
                if (i < (0.5f * ROW)) {
                    // first half of the rows: feature 1 is always present
                    if (j == 1) {
                        feature.add(new IntFeature(j, 1.d));
                    } else if (j < 0.5 * COL) {
                        // lower columns are kept when the random draw is below 0.2
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new IntFeature(j, rnd.nextDouble()));
                        }
                    }
                } else {
                    // second half of the rows: only upper columns, same 0.2 threshold
                    if (j > 0.5f * COL) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new IntFeature(j, rnd.nextDouble()));
                        }
                    }
                }
            }
            IntFeature[] x = new IntFeature[feature.size()];
            feature.toArray(x);
            fArrayList.add(x);

            // label: -1 for the first half of the rows, +1 for the second half
            final double y;
            if (i < ROW * 0.5f) {
                y = -1.0d;
            } else {
                y = 1.0d;
            }
            ans.add(y);

            udtf.process(new Object[] {toStringArray(x), y});
        }
        // score this round's rows with the model
        int bingo = 0;
        int total = fArrayList.size();
        for (int i = 0; i < total; i++) {
            // map the -1/+1 label to 0/1 so it can be compared with the thresholded prediction
            double tmpAns = ans.get(i);
            if (tmpAns < 0) {
                tmpAns = 0;
            } else {
                tmpAns = 1;
            }
            double p = model.predict(fArrayList.get(i));
            int predicted = p > 0.5 ? 1 : 0;
            if (predicted == tmpAns) {
                bingo++;
            }
        }
        // overwritten every round; the final assertion sees only the last round's value
        accuracy = bingo / (float) total;
        println("Accuracy = " + accuracy);
    }
    // runs after the accuracy has been measured, so it does not affect the asserted value
    udtf.runTrainingIteration(10);
    Assert.assertTrue(accuracy > 0.95f);
}
 
Example 17
/**
 * Configures the evaluator for the given aggregation mode and declares its output type.
 *
 * <p>In {@code PARTIAL1} and {@code COMPLETE} modes the arguments are read directly: the
 * weight inspector, the constant {@code num_of_bins}, and an optional constant auto-shrink
 * flag when three arguments are present. In the other modes the single argument is a struct
 * with the fields {@code autoShrink}, {@code histogram} and {@code quantiles}.
 *
 * @param mode the aggregation mode
 * @param OIs inspectors of the inputs for this mode
 * @return the partial-result struct inspector for {@code PARTIAL1}/{@code PARTIAL2},
 *         otherwise a list-of-writable-double inspector
 * @throws HiveException a {@code UDFArgumentException} when {@code num_of_bins} is below 2,
 *         or whatever the superclass and helper calls raise
 */
@Override
public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
    super.init(mode, OIs);

    final boolean readsOriginalArgs = (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE);
    if (!readsOriginalArgs) {
        // input is the partial-aggregation struct
        structOI = (StructObjectInspector) OIs[0];
        autoShrinkField = structOI.getStructFieldRef("autoShrink");
        histogramField = structOI.getStructFieldRef("histogram");
        quantilesField = structOI.getStructFieldRef("quantiles");
        autoShrinkOI = (WritableBooleanObjectInspector) autoShrinkField.getFieldObjectInspector();
        histogramOI = (StandardListObjectInspector) histogramField.getFieldObjectInspector();
        quantilesOI = (StandardListObjectInspector) quantilesField.getFieldObjectInspector();
        histogramElOI = (WritableDoubleObjectInspector) histogramOI.getListElementObjectInspector();
        quantileOI = (WritableDoubleObjectInspector) quantilesOI.getListElementObjectInspector();
    } else {
        weightOI = HiveUtils.asDoubleCompatibleOI(OIs[0]);

        // constant arguments
        nBins = HiveUtils.getConstInt(OIs[1]);
        if (OIs.length == 3) {
            autoShrink = HiveUtils.getConstBoolean(OIs[2]);
        }

        // `num_of_bins` must be at least 2
        if (nBins < 2) {
            throw new UDFArgumentException(
                "Only greater than or equal to 2 is accepted but " + nBins
                        + " was passed as `num_of_bins`.");
        }

        quantiles = getQuantiles();
    }

    final boolean emitsPartial = (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2);
    if (!emitsPartial) {
        return ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    }

    // partial result: struct<autoShrink, histogram, quantiles>
    final ArrayList<ObjectInspector> partialFieldOIs = new ArrayList<ObjectInspector>();
    partialFieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector);
    partialFieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
    partialFieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));

    return ObjectInspectorFactory.getStandardStructObjectInspector(
        Arrays.asList("autoShrink", "histogram", "quantiles"), partialFieldOIs);
}
 
Example 18
/**
 * Runs {@code SignalNoiseRatioUDAF} over five rows spread across three one-hot encoded
 * classes and compares the terminated result with the expected values
 * {@code {+Infinity, 121.99999999999989, 0, 28.761904761904734}} (tolerance 1e-5).
 */
@Test
public void snrMultipleClassCornerCase0() throws Exception {
    final ObjectInspector featureListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    final ObjectInspector labelListOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    final ObjectInspector[] inputOIs = {featureListOI, labelListOI};

    final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
    final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator =
            (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(
                new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer buffer =
            (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
    evaluator.reset(buffer);

    // column 0: the same value within class 0 and within class 1
    // column 1: the same value within class 1, which the single class-2 row also shares
    // column 2: constant across all rows
    // column 3: different in every row
    final double[][] features = new double[][] {
            {3.5, 1.4, 0.3, 5.1},
            {3.5, 1.5, 0.3, 5.2},
            {3.5, 4.5, 0.3, 7.d},
            {3.5, 4.5, 0.3, 6.4},
            {3.3, 4.5, 0.3, 6.3}};

    final int[][] labels = new int[][] {
            {1, 0, 0}, {1, 0, 0}, // class `0`
            {0, 1, 0}, {0, 1, 0}, // class `1`
            {0, 0, 1}}; // class `2`, only single entry

    for (int row = 0; row < features.length; row++) {
        final int[] oneHot = labels[row];
        final List<IntWritable> labelList = new ArrayList<IntWritable>();
        for (int k = 0; k < oneHot.length; k++) {
            labelList.add(new IntWritable(oneHot[k]));
        }
        final Object[] rowArgs = {WritableUtils.toWritableList(features[row]), labelList};
        evaluator.iterate(buffer, rowArgs);
    }

    @SuppressWarnings("unchecked")
    final List<DoubleWritable> snrValues = (List<DoubleWritable>) evaluator.terminate(buffer);
    final double[] actual = new double[snrValues.size()];
    int pos = 0;
    for (DoubleWritable value : snrValues) {
        actual[pos] = value.get();
        pos++;
    }

    final double[] expected = new double[] {Double.POSITIVE_INFINITY, 121.99999999999989, 0.d,
            28.761904761904734};

    Assert.assertArrayEquals(expected, actual, 1e-5);
}
 
Example 19
/**
 * Trains {@code FactorizationMachineUDTF} on {@code 5107786.txt.gz} twice, first with the
 * base options and then with {@code -adaptive_regularization -validation_threshold 1}
 * appended, and asserts that the second run reports a lower average loss.
 *
 * <p>Each input line holds the target value followed by space-separated features. The rows
 * read in the first pass are kept in memory and replayed for the second pass.
 *
 * @throws HiveException if the UDTF rejects its arguments or fails while training
 * @throws IOException if the data file cannot be read or closed
 */
@Test
public void testAdaptiveRegularization() throws HiveException, IOException {
    println("Adaptive regularization test");

    final String options = "-factors 5 -min 1 -max 5 -init_v gaussian -eta0 0.01 -seed 31 ";

    FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
    ObjectInspector[] argOIs = new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, options)};

    udtf.initialize(argOIs);

    List<List<String>> featureVectors = new ArrayList<>();
    List<Double> ys = new ArrayList<>();
    // try-with-resources closes the reader even when parsing or training throws;
    // previously it was closed only when the whole first pass succeeded
    try (BufferedReader data = readFile("5107786.txt.gz")) {
        String line = data.readLine();
        while (line != null) {
            StringTokenizer tokenizer = new StringTokenizer(line, " ");
            double y = Double.parseDouble(tokenizer.nextToken());
            List<String> features = new ArrayList<String>();
            while (tokenizer.hasMoreTokens()) {
                String f = tokenizer.nextToken();
                features.add(f);
            }
            udtf.process(new Object[] {features, y});
            featureVectors.add(features);
            ys.add(y);
            line = data.readLine();
        }
    }
    udtf.finalizeTraining();

    double loss = udtf._cvState.getAverageLoss(featureVectors.size());
    println("Average loss without adaptive regularization: " + loss);

    // train with adaptive regularization
    udtf = new FactorizationMachineUDTF();
    argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        options + "-adaptive_regularization -validation_threshold 1");
    udtf.initialize(argOIs);
    udtf.initModel(udtf._params);
    for (int i = 0, n = featureVectors.size(); i < n; i++) {
        udtf.process(new Object[] {featureVectors.get(i), ys.get(i)});
    }
    udtf.finalizeTraining();

    double loss_adareg = udtf._cvState.getAverageLoss(featureVectors.size());
    println("Average loss with adaptive regularization: " + loss_adareg);
    Assert.assertTrue("Adaptive regularization should achieve lower loss", loss > loss_adareg);
}
 
Example 20
/**
 * Validates the two arguments of {@code feature_binning} and selects one of its two forms.
 *
 * <ul>
 * <li>{@code (array<string> features, map<string, array<number>> quantiles_map)}: sets
 * {@code multiple} to true and returns a list-of-writable-string inspector.</li>
 * <li>{@code (number weight, array<number> quantiles)}: sets {@code multiple} to false and
 * returns the writable int inspector.</li>
 * </ul>
 *
 * @param argOIs inspectors of the call arguments
 * @return the inspector of the result for the selected form
 * @throws UDFArgumentException a {@code UDFArgumentLengthException} unless exactly two
 *         arguments are given, or a {@code UDFArgumentTypeException} when the argument types
 *         match neither form
 */
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    if (argOIs.length != 2) {
        throw new UDFArgumentLengthException("Specify two arguments :" + argOIs.length);
    }

    final ObjectInspector firstOI = argOIs[0];
    final ObjectInspector secondOI = argOIs[1];

    if (HiveUtils.isListOI(firstOI) && HiveUtils.isMapOI(secondOI)) {
        // feature_binning(array<features::string> features, map<string, array<number>> quantiles_map)

        final boolean stringElements = HiveUtils.isStringOI(
            ((ListObjectInspector) firstOI).getListElementObjectInspector());
        if (!stringElements) {
            throw new UDFArgumentTypeException(0,
                "Only array<string> type argument can be accepted but "
                        + firstOI.getTypeName() + " was passed as `features`");
        }
        featuresOI = HiveUtils.asListOI(firstOI);
        featureOI = HiveUtils.asStringOI(featuresOI.getListElementObjectInspector());

        quantilesMapOI = HiveUtils.asMapOI(secondOI);
        // short-circuit order matters: the cast is only reached once the value is a list
        final boolean validQuantilesMap =
                HiveUtils.isStringOI(quantilesMapOI.getMapKeyObjectInspector())
                        && HiveUtils.isListOI(quantilesMapOI.getMapValueObjectInspector())
                        && HiveUtils.isNumberOI(
                            ((ListObjectInspector) quantilesMapOI.getMapValueObjectInspector()).getListElementObjectInspector());
        if (!validQuantilesMap) {
            throw new UDFArgumentTypeException(1,
                "Only map<string, array<number>> type argument can be accepted but "
                        + secondOI.getTypeName() + " was passed as `quantiles_map`");
        }
        keyOI = HiveUtils.asStringOI(quantilesMapOI.getMapKeyObjectInspector());
        quantilesOI = HiveUtils.asListOI(quantilesMapOI.getMapValueObjectInspector());
        quantileOI =
                HiveUtils.asDoubleCompatibleOI(quantilesOI.getListElementObjectInspector());

        multiple = true;

        return ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    }

    if (HiveUtils.isPrimitiveOI(firstOI) && HiveUtils.isListOI(secondOI)) {
        // feature_binning(number weight, array<number> quantiles)

        weightOI = HiveUtils.asDoubleCompatibleOI(firstOI);

        quantilesOI = HiveUtils.asListOI(secondOI);
        final boolean numericQuantiles =
                HiveUtils.isNumberOI(quantilesOI.getListElementObjectInspector());
        if (!numericQuantiles) {
            throw new UDFArgumentTypeException(1,
                "Only array<number> type argument can be accepted but "
                        + secondOI.getTypeName() + " was passed as `quantiles`");
        }
        quantileOI =
                HiveUtils.asDoubleCompatibleOI(quantilesOI.getListElementObjectInspector());

        multiple = false;

        return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    }

    // neither supported signature matched
    throw new UDFArgumentTypeException(0,
        "Only <array<features::string>, map<string, array<number>>> "
                + "or <number, array<number>> type arguments can be accepted but <"
                + firstOI.getTypeName() + ", " + secondOI.getTypeName()
                + "> was passed.");
}