Java Code Examples for org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory#javaFloatObjectInspector

The following examples show how to use the static field org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory#javaFloatObjectInspector. You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the links above each example. You may also check out related API usage in the sidebar.
Example 1
Source File: GeneralRegressorUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 6 votes
/**
 * Trains a {@code GeneralRegressorUDTF} that is initialized without an options argument on a
 * single example, then checks that the prediction for that example's features is within 1e-5
 * of its label.
 */
@Test
public void testNoOptions() throws Exception {
    List<String> x = Arrays.asList("1:-2", "2:-1");
    float y = 0.f;

    GeneralRegressorUDTF udtf = new GeneralRegressorUDTF();
    // Inspector for the label y, which is a float value.
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
    ListObjectInspector stringListOI =
            ObjectInspectorFactory.getStandardListObjectInspector(stringOI);

    // Only features and label are declared; no constant options argument is passed.
    udtf.initialize(new ObjectInspector[] {stringListOI, floatOI});

    udtf.process(new Object[] {x, y});

    udtf.finalizeTraining();

    float predicted = udtf.predict(udtf.parseFeatures(x));
    Assert.assertEquals(y, predicted, 1E-5);
}
 
Example 2
Source File: FeatureUDFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Feeds an int feature and a float weight to the feature UDF and expects the text
 * {@code "1:2.5"}.
 */
@Test
public void testIntFloat() throws Exception {
    final ObjectInspector[] argOIs = {PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.javaFloatObjectInspector};
    udf.initialize(argOIs);

    final GenericUDF.DeferredObject feature = new DeferredJavaObject(1);
    final GenericUDF.DeferredObject weight = new DeferredJavaObject(2.5f);
    final Text formatted = udf.evaluate(new GenericUDF.DeferredObject[] {feature, weight});

    Assert.assertEquals("1:2.5", formatted.toString());
}
 
Example 3
Source File: GeneralRegressorUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Expects {@code initialize} to throw {@link UDFArgumentException} when the options string is
 * {@code "-reg UnsupportedReg"}.
 */
@Test(expected = UDFArgumentException.class)
public void testUnsupportedRegularization() throws Exception {
    final GeneralRegressorUDTF udtf = new GeneralRegressorUDTF();

    final ObjectInspector[] argOIs = new ObjectInspector[3];
    // features: list of strings
    argOIs[0] = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    // label: float
    argOIs[1] = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // options: constant string
    argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-reg UnsupportedReg");

    udtf.initialize(argOIs);
}
 
Example 4
Source File: GeneralRegressorUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Expects {@code initialize} to throw {@link UDFArgumentException} when the options string is
 * {@code "-loss HingeLoss"}.
 */
@Test(expected = UDFArgumentException.class)
public void testInvalidLossFunction() throws Exception {
    final GeneralRegressorUDTF udtf = new GeneralRegressorUDTF();

    final ObjectInspector[] argOIs = new ObjectInspector[3];
    // features: list of strings
    argOIs[0] = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    // label: float
    argOIs[1] = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // options: constant string
    argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-loss HingeLoss");

    udtf.initialize(argOIs);
}
 
Example 5
Source File: GeneralRegressorUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Expects {@code initialize} to throw {@link UDFArgumentException} when the options string is
 * {@code "-loss UnsupportedLoss"}.
 */
@Test(expected = UDFArgumentException.class)
public void testUnsupportedLossFunction() throws Exception {
    final GeneralRegressorUDTF udtf = new GeneralRegressorUDTF();

    final ObjectInspector[] argOIs = new ObjectInspector[3];
    // features: list of strings
    argOIs[0] = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    // label: float
    argOIs[1] = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // options: constant string
    argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-loss UnsupportedLoss");

    udtf.initialize(argOIs);
}
 
Example 6
Source File: GeneralRegressorUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Expects {@code initialize} to throw {@link UDFArgumentException} when the options string is
 * {@code "-opt UnsupportedOpt"}.
 */
@Test(expected = UDFArgumentException.class)
public void testUnsupportedOptimizer() throws Exception {
    final GeneralRegressorUDTF udtf = new GeneralRegressorUDTF();

    final ObjectInspector[] argOIs = new ObjectInspector[3];
    // features: list of strings
    argOIs[0] = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    // label: float
    argOIs[1] = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // options: constant string
    argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-opt UnsupportedOpt");

    udtf.initialize(argOIs);
}
 
Example 7
Source File: AdaGradUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Checks the output schema returned by {@code AdaGradUDTF#initialize} when the feature list has
 * int, string, or bigint elements and the label is a float. In each case the feature column
 * follows the element type and the weight column is float.
 */
@SuppressWarnings("deprecation") // presumably AdaGradUDTF or an API it uses is deprecated -- confirm
@Test
public void testInitialize() throws UDFArgumentException {
    AdaGradUDTF udtf = new AdaGradUDTF();
    ObjectInspector floatLabelOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;

    /* list<int> features */
    ListObjectInspector intFeatures = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaIntObjectInspector);
    StructObjectInspector intSchema =
            udtf.initialize(new ObjectInspector[] {intFeatures, floatLabelOI});
    assertEquals("struct<feature:int,weight:float>", intSchema.getTypeName());

    /* list<string> features */
    ListObjectInspector stringFeatures = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    StructObjectInspector stringSchema =
            udtf.initialize(new ObjectInspector[] {stringFeatures, floatLabelOI});
    assertEquals("struct<feature:string,weight:float>", stringSchema.getTypeName());

    /* list<bigint> features */
    ListObjectInspector longFeatures = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaLongObjectInspector);
    StructObjectInspector longSchema =
            udtf.initialize(new ObjectInspector[] {longFeatures, floatLabelOI});
    assertEquals("struct<feature:bigint,weight:float>", longSchema.getTypeName());
}
 
Example 8
Source File: MovingAverageUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Streams the floats 1..7 through {@code MovingAverageUDTF}, initialized with the constant int 3,
 * and checks the sequence of averages it emits. Each emitted row must hold exactly one
 * {@link DoubleWritable}.
 */
@Test
public void test() throws HiveException {
    MovingAverageUDTF udtf = new MovingAverageUDTF();

    ObjectInspector valueOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    ObjectInspector windowOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaIntObjectInspector, 3);

    final List<Double> averages = new ArrayList<>();
    udtf.initialize(new ObjectInspector[] {valueOI, windowOI});
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            Object[] row = (Object[]) input;
            Assert.assertEquals(1, row.length);
            Assert.assertTrue(row[0] instanceof DoubleWritable);
            averages.add(((DoubleWritable) row[0]).get());
        }
    });

    // Values 1.0f .. 7.0f, boxed as Float. The second slot is left null; its value presumably
    // comes from the constant inspector above.
    for (int i = 1; i <= 7; i++) {
        udtf.process(new Object[] {(float) i, null});
    }

    Assert.assertEquals(Arrays.asList(1.d, 1.5d, 2.d, 3.d, 4.d, 5.d, 6.d), averages);
}
 
Example 9
Source File: FeatureUDFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Feeds a string feature and a float weight to the feature UDF and expects the text
 * {@code "f1:2.5"}.
 */
@Test
public void testStringFloat() throws Exception {
    final ObjectInspector[] argOIs = {PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            PrimitiveObjectInspectorFactory.javaFloatObjectInspector};
    udf.initialize(argOIs);

    final GenericUDF.DeferredObject feature = new DeferredJavaObject("f1");
    final GenericUDF.DeferredObject weight = new DeferredJavaObject(2.5f);
    final Text formatted = udf.evaluate(new GenericUDF.DeferredObject[] {feature, weight});

    Assert.assertEquals("f1:2.5", formatted.toString());
}
 
Example 10
Source File: FeatureUDFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Feeds a long feature and a float weight to the feature UDF and expects the text
 * {@code "1:2.5"}.
 */
@Test
public void testLongFloat() throws Exception {
    final ObjectInspector[] argOIs = {PrimitiveObjectInspectorFactory.javaLongObjectInspector,
            PrimitiveObjectInspectorFactory.javaFloatObjectInspector};
    udf.initialize(argOIs);

    final GenericUDF.DeferredObject feature = new DeferredJavaObject(1L);
    final GenericUDF.DeferredObject weight = new DeferredJavaObject(2.5f);
    final Text formatted = udf.evaluate(new GenericUDF.DeferredObject[] {feature, weight});

    Assert.assertEquals("1:2.5", formatted.toString());
}
 
Example 11
Source File: GeneralRegressorUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Trains a {@code GeneralRegressorUDTF} on one example with label 1.0, closes it, and checks
 * that at least one row was forwarded. It also checks that the first column of every forwarded
 * row has runtime class {@code modelFeatureClass}.
 *
 * @param x feature list of the single training example
 * @param featureOI inspector for the elements of {@code x}
 * @param featureClass element type of {@code x}; not referenced in this method's body
 * @param modelFeatureClass expected runtime class of every forwarded model feature
 */
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    final float label = 1.f;

    GeneralRegressorUDTF udtf = new GeneralRegressorUDTF();
    ListObjectInspector featuresOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    udtf.initialize(new ObjectInspector[] {featuresOI,
            PrimitiveObjectInspectorFactory.javaFloatObjectInspector});

    final List<Object> forwardedFeatures = new ArrayList<Object>();
    udtf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            // Column 0 of each forwarded row is the model feature.
            forwardedFeatures.add(((Object[]) input)[0]);
        }
    });

    udtf.process(new Object[] {x, label});
    udtf.close();

    Assert.assertFalse(forwardedFeatures.isEmpty());
    for (Object feature : forwardedFeatures) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            feature.getClass());
    }
}
 
Example 12
Source File: MatrixFactorizationSGDUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Configures matrix factorization with {@code -factor 3 -iterations 100} and a MapredContext.
 * It makes one pass over a 5x4 rating matrix through {@code process}, then calls
 * {@code runIterativeTraining(100)}. Every prediction must be within 0.2 of the corresponding
 * rating. The rank initialization is expected to be {@link RankInitScheme#random}.
 */
@Test
public void testIterationsWithoutFile() throws HiveException {
    println("--------------------------\n testIterationsWithoutFile()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 100;
    // The concatenated string is passed directly; wrapping it in new String(...) is a needless copy.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    // assertEquals reports expected vs. actual on failure, unlike assertTrue(a == b).
    Assert.assertEquals(RankInitScheme.random, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    for (int row = 0; row < rating.length; row++) {
        for (int col = 0, size = rating[row].length; col < size; col++) {
            args[0] = row;
            args[1] = col;
            args[2] = rating[row][col]; // already a float; boxed to Float
            mf.process(args);
        }
    }
    mf.runIterativeTraining(iters);

    for (int row = 0; row < rating.length; row++) {
        for (int col = 0, size = rating[row].length; col < size; col++) {
            double predicted = mf.predict(row, col);
            print(rating[row][col] + "[" + predicted + "]\t");
            Assert.assertEquals(rating[row][col], predicted, 0.2d);
        }
        println();
    }
}
 
Example 13
Source File: MatrixFactorizationSGDUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Checks that {@code -rankinit gaussian} selects {@link RankInitScheme#gaussian}. It then makes
 * 100 passes over a 5x4 rating matrix through {@code process} and expects every prediction to be
 * within 0.2 of the corresponding rating.
 */
@Test
public void testGaussianInit() throws HiveException {
    println("--------------------------\n testGaussianInit()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // The literal is passed directly; wrapping it in new String(...) is a needless copy.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-factor 3 -rankinit gaussian");
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    mf.initialize(argOIs);
    // assertEquals reports expected vs. actual on failure, unlike assertTrue(a == b).
    Assert.assertEquals(RankInitScheme.gaussian, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 100;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = rating[row][col]; // already a float; boxed to Float
                mf.process(args);
            }
        }
    }
    for (int row = 0; row < rating.length; row++) {
        for (int col = 0, size = rating[row].length; col < size; col++) {
            double predicted = mf.predict(row, col);
            print(rating[row][col] + "[" + predicted + "]\t");
            Assert.assertEquals(rating[row][col], predicted, 0.2d);
        }
        println();
    }
}
 
Example 14
Source File: MatrixFactorizationSGDUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * Checks that {@code -rankinit random} selects {@link RankInitScheme#random}. It then makes
 * 100 passes over a 5x4 rating matrix through {@code process} and expects every prediction to be
 * within 0.2 of the corresponding rating.
 */
@Test
public void testRandInit() throws HiveException {
    println("--------------------------\n testRandInit()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // The literal is passed directly; wrapping it in new String(...) is a needless copy.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-factor 3 -rankinit random");
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    mf.initialize(argOIs);
    // assertEquals reports expected vs. actual on failure, unlike assertTrue(a == b).
    Assert.assertEquals(RankInitScheme.random, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 100;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = rating[row][col]; // already a float; boxed to Float
                mf.process(args);
            }
        }
    }
    for (int row = 0; row < rating.length; row++) {
        for (int col = 0, size = rating[row].length; col < size; col++) {
            double predicted = mf.predict(row, col);
            print(rating[row][col] + "[" + predicted + "]\t");
            Assert.assertEquals(rating[row][col], predicted, 0.2d);
        }
        println();
    }
}
 
Example 15
Source File: MatrixFactorizationSGDUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 5 votes
/**
 * With only {@code -factor 3} given, checks that the rank initialization is
 * {@link RankInitScheme#random}. It then makes 100 passes over a 5x4 rating matrix through
 * {@code process} and expects every prediction to be within 0.2 of the corresponding rating.
 */
@Test
public void testDefaultInit() throws HiveException {
    // The banner names this test; it previously printed the wrong name ("testGaussian()").
    println("--------------------------\n testDefaultInit()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-factor 3");
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    mf.initialize(argOIs);
    // assertEquals reports expected vs. actual on failure, unlike assertTrue(a == b).
    Assert.assertEquals(RankInitScheme.random, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 100;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = rating[row][col]; // already a float; boxed to Float
                mf.process(args);
            }
        }
    }
    for (int row = 0; row < rating.length; row++) {
        for (int col = 0, size = rating[row].length; col < size; col++) {
            double predicted = mf.predict(row, col);
            print(rating[row][col] + "[" + predicted + "]\t");
            Assert.assertEquals(rating[row][col], predicted, 0.2d);
        }
        println();
    }
}
 
Example 16
Source File: HiveJdbcBridgeUtils.java | From HiveJdbcStorageHandler, Apache License 2.0 | 5 votes
/**
 * Maps a JDBC SQL type to the Hive Java-object inspector used for values of that type.
 *
 * @param sqlType a {@link java.sql.Types} constant describing the column
 * @param hiveType the Hive type name of the column. It is parsed only for {@code Types.ARRAY},
 *        where it must look like {@code array<elementType>}, and it is echoed in error messages.
 * @return the inspector for the column's Java representation; ARRAY columns get a standard list
 *         inspector whose element inspector is resolved recursively
 * @throws SerDeException if {@code sqlType} is not supported, or if an ARRAY column's
 *         {@code hiveType} is missing or malformed
 */
public static ObjectInspector getObjectInspector(int sqlType, String hiveType)
        throws SerDeException {
    switch (sqlType) {
        case Types.VARCHAR:
            return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        case Types.FLOAT:
            return PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
        case Types.DOUBLE:
            return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        case Types.BOOLEAN:
            return PrimitiveObjectInspectorFactory.javaBooleanObjectInspector;
        case Types.TINYINT:
            return PrimitiveObjectInspectorFactory.javaByteObjectInspector;
        case Types.SMALLINT:
            return PrimitiveObjectInspectorFactory.javaShortObjectInspector;
        case Types.INTEGER:
            return PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        case Types.BIGINT:
            return PrimitiveObjectInspectorFactory.javaLongObjectInspector;
        case Types.TIMESTAMP:
            return PrimitiveObjectInspectorFactory.javaTimestampObjectInspector;
        case Types.BINARY:
            return PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector;
        case Types.ARRAY: {
            // Take the text between the first '<' and the LAST '>'. Using the first '>' would
            // truncate nested element types such as "array<array<int>>" to "array<int".
            int open = (hiveType == null) ? -1 : hiveType.indexOf('<');
            int close = (hiveType == null) ? -1 : hiveType.lastIndexOf('>');
            if (open < 0 || close <= open) {
                // Report as SerDeException rather than StringIndexOutOfBounds/NullPointerException.
                throw new SerDeException("Malformed Hive array type: " + hiveType);
            }
            String hiveElemType = hiveType.substring(open + 1, close).trim();
            if (hiveElemType.isEmpty()) {
                throw new SerDeException("Malformed Hive array type: " + hiveType);
            }
            int sqlElemType = hiveTypeToSqlType(hiveElemType);
            ObjectInspector listElementOI = getObjectInspector(sqlElemType, hiveElemType);
            return ObjectInspectorFactory.getStandardListObjectInspector(listElementOI);
        }
        default:
            throw new SerDeException("Cannot find ObjectInspector for: " + hiveType
                    + " (SQL type " + sqlType + ")");
    }
}
 
Example 17
Source File: HiveTypeSystem.java | From transport, BSD 2-Clause "Simplified" License | 4 votes
/**
 * Returns the inspector this type system uses for float values: the static
 * {@code PrimitiveObjectInspectorFactory.javaFloatObjectInspector}.
 */
@Override
protected ObjectInspector createFloatType() {
  final ObjectInspector floatInspector = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
  return floatInspector;
}
 
Example 18
Source File: OrcUtils.java | From spork, Apache License 2.0 | 4 votes
/**
 * Builds an ObjectInspector for a Hive {@link TypeInfo}.
 *
 * <p>FLOAT, DOUBLE, BOOLEAN, INT, LONG and STRING use Hive's Java primitive inspectors. TIMESTAMP,
 * DECIMAL and BINARY use Pig-specific inspectors. STRUCT, MAP and LIST are wrapped in the
 * corresponding Pig inspector.
 *
 * @param info the type to inspect
 * @return an inspector for {@code info}
 * @throws IllegalArgumentException for the primitive categories DATE, VARCHAR, BYTE and SHORT
 *         (rejected as not valid Pig primitive types), and for any unrecognized primitive
 *         category or type category
 */
public static ObjectInspector createObjectInspector(TypeInfo info) {
    switch (info.getCategory()) {
        case PRIMITIVE: {
            PrimitiveTypeInfo primitiveInfo = (PrimitiveTypeInfo) info;
            // Every branch returns or throws, so control never falls through to STRUCT.
            switch (primitiveInfo.getPrimitiveCategory()) {
                case FLOAT:
                    return PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
                case DOUBLE:
                    return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
                case BOOLEAN:
                    return PrimitiveObjectInspectorFactory.javaBooleanObjectInspector;
                case INT:
                    return PrimitiveObjectInspectorFactory.javaIntObjectInspector;
                case LONG:
                    return PrimitiveObjectInspectorFactory.javaLongObjectInspector;
                case STRING:
                    return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                case TIMESTAMP:
                    return new PigJodaTimeStampObjectInspector();
                case DECIMAL:
                    return new PigDecimalObjectInspector();
                case BINARY:
                    return new PigDataByteArrayObjectInspector();
                case DATE:
                case VARCHAR:
                case BYTE:
                case SHORT:
                    // The category needs a leading space; the old message ran it into "is".
                    throw new IllegalArgumentException("Should never happen, "
                            + primitiveInfo.getPrimitiveCategory()
                            + " is not a valid Pig primitive data type");
                default:
                    throw new IllegalArgumentException("Unknown primitive type "
                            + primitiveInfo.getPrimitiveCategory());
            }
        }
        case STRUCT:
            return new PigStructInspector((StructTypeInfo) info);
        case MAP:
            return new PigMapObjectInspector((MapTypeInfo) info);
        case LIST:
            return new PigListObjectInspector((ListTypeInfo) info);
        default:
            throw new IllegalArgumentException("Unknown type " + info.getCategory());
    }
}
 
Example 19
Source File: MatrixFactorizationSGDUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 4 votes
/**
 * Configures matrix factorization with {@code -factor 3 -iterations 10} and a MapredContext,
 * then makes 500 passes over a 5x4 rating matrix through {@code process} and closes the UDTF.
 *
 * <p>After close it expects:
 * <ul>
 *   <li>{@code mf.count} to be below {@code trainingExamples * iters}; presumably iterative
 *       training stops early on convergence (confirm against the UDTF);</li>
 *   <li>exactly 5 rows forwarded to the collector;</li>
 *   <li>the file returned by {@code mf.fileIO.getFile()} to no longer exist.</li>
 * </ul>
 */
@Test
public void testFileBackedIterationsCloseWithConverge() throws HiveException {
    println("--------------------------\n testFileBackedIterationsCloseWithConverge()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 10;
    // The concatenated string is passed directly; wrapping it in new String(...) is a needless copy.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    // assertEquals reports expected vs. actual on failure, unlike assertTrue(a == b).
    Assert.assertEquals(RankInitScheme.random, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];

    final int num_iters = 500;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = rating[row][col]; // already a float; boxed to Float
                mf.process(args);
                trainingExamples++;
            }
        }
    }

    // Capture the file before close() so its removal can be checked afterwards.
    File tmpFile = mf.fileIO.getFile();
    mf.close();
    Assert.assertTrue(mf.count < trainingExamples * iters);
    Assert.assertEquals(5, numCollected.intValue());
    Assert.assertFalse(tmpFile.exists());
}
 
Example 20
Source File: MatrixFactorizationSGDUDTFTest.java | From incubator-hivemall, Apache License 2.0 | 4 votes
/**
 * Configures matrix factorization with {@code -factor 3 -iterations 3} and a MapredContext,
 * then makes 100 passes over a 5x4 rating matrix through {@code process} and closes the UDTF.
 * After close it expects {@code mf.count} to equal {@code trainingExamples * iters} exactly, and
 * exactly 5 rows to have been forwarded to the collector.
 */
@Test
public void testIterationsCloseWithoutFile() throws HiveException {
    println("--------------------------\n testIterationsCloseWithoutFile()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 3;
    // The concatenated string is passed directly; wrapping it in new String(...) is a needless copy.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    // assertEquals reports expected vs. actual on failure, unlike assertTrue(a == b).
    Assert.assertEquals(RankInitScheme.random, mf.rankInit);

    float[][] rating = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];

    final int num_iters = 100;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = rating[row][col]; // already a float; boxed to Float
                mf.process(args);
                trainingExamples++;
            }
        }
    }
    mf.close();
    Assert.assertEquals(trainingExamples * iters, mf.count);
    Assert.assertEquals(5, numCollected.intValue());
}