Java Code Examples for org.apache.spark.sql.types.DataTypes#DoubleType

The following examples show how to use org.apache.spark.sql.types.DataTypes#DoubleType. Each example is drawn from an open-source project; the source file and license are noted above the code.
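Before the examples, a minimal sketch of the pattern nearly all of them share: DataTypes.DoubleType is a singleton DataType instance that is typically referenced when declaring a StructField in a schema. The field names here are illustrative, not taken from any of the projects below.

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// A two-column schema in which "price" is a Spark SQL double.
StructType schema = new StructType(new StructField[]{
        new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
        new StructField("price", DataTypes.DoubleType, false, Metadata.empty())
});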
Example 1
Source File: DBClientWrapper.java    From spark-data-sources with MIT License
public static edb.common.Row sparkToDBRow(org.apache.spark.sql.Row row, StructType type) {
    edb.common.Row dbRow = new edb.common.Row();
    StructField[] fields = type.fields();
    for (int i = 0; i < type.size(); i++) {
        StructField sf = fields[i];
        if (sf.dataType() == DataTypes.StringType) {
            dbRow.addField(new edb.common.Row.StringField(sf.name(), row.getString(i)));
        } else if (sf.dataType() == DataTypes.DoubleType) {
            dbRow.addField(new edb.common.Row.DoubleField(sf.name(), row.getDouble(i)));
        } else if (sf.dataType() == DataTypes.LongType) {
            dbRow.addField(new edb.common.Row.Int64Field(sf.name(), row.getLong(i)));
        } else {
            // TODO: type leakage
        }
    }

    return dbRow;
}
 
Example 2
Source File: TestRangeRowRule.java    From envelope with Apache License 2.0
@Test
public void testRangeDataTypes() throws Exception {
  Config config = ConfigUtils.configFromResource("/dq/dq-range-rules.conf").getConfig("steps");
  StructType schema = new StructType(new StructField[] {
    new StructField("fa", DataTypes.LongType, false, Metadata.empty()),
    new StructField("fi", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("fl", DataTypes.LongType, false, Metadata.empty()),
    new StructField("ff", DataTypes.FloatType, false, Metadata.empty()),
    new StructField("fe", DataTypes.DoubleType, false, Metadata.empty()),
    new StructField("fd", DataTypes.createDecimalType(), false, Metadata.empty())
  });
  Row row = new RowWithSchema(schema, new Long(2), 2, new Long(2), new Float(2.0), 2.0, new BigDecimal("2.0"));
    
  ConfigObject rro = config.getObject("dq1.deriver.rules");
  for (String rulename : rro.keySet()) {
    Config rrc = rro.toConfig().getConfig(rulename);
    RangeRowRule rrr = new RangeRowRule();
    rrr.configure(rrc);
    rrr.configureName(rulename);
    assertTrue("Row should pass rule " + rulename, rrr.check(row));
  }
}
 
Example 3
Source File: DataFrames.java    From deeplearning4j with Apache License 2.0
/**
 * Convert a DataVec schema to a
 * Spark StructType.
 *
 * @param schema the schema to convert
 * @return the Spark StructType for the schema
 */
public static StructType fromSchema(Schema schema) {
    StructField[] structFields = new StructField[schema.numColumns()];
    for (int i = 0; i < structFields.length; i++) {
        switch (schema.getColumnTypes().get(i)) {
            case Double:
                structFields[i] = new StructField(schema.getName(i), DataTypes.DoubleType, false, Metadata.empty());
                break;
            case Integer:
                structFields[i] =
                                new StructField(schema.getName(i), DataTypes.IntegerType, false, Metadata.empty());
                break;
            case Long:
                structFields[i] = new StructField(schema.getName(i), DataTypes.LongType, false, Metadata.empty());
                break;
            case Float:
                structFields[i] = new StructField(schema.getName(i), DataTypes.FloatType, false, Metadata.empty());
                break;
            default:
                throw new IllegalStateException(
                                "This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
        }
    }
    return new StructType(structFields);
}
 
Example 4
Source File: TypeCastStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
private DataType mapDataType(List<StructField> datasetFields, String column, String typeConfig) {

        DataType currentDatatype = getCurrentDataType(datasetFields, column);

        // when typeConfig is null (no config for this column), return the current DataType
        if(typeConfig == null) {
            return currentDatatype;
        }

        switch (typeConfig) {
            case "integer":
                return DataTypes.IntegerType;
            case "long":
                return DataTypes.LongType;
            case "double":
                return DataTypes.DoubleType;
            case "boolean":
                return DataTypes.BooleanType;
            case "date":
                return DataTypes.DateType;
            case "timestamp":
                return DataTypes.TimestampType;
            default:
                return DataTypes.StringType;
        }
    }
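A DataType resolved by mapDataType is typically applied by casting the column. A hedged sketch of such a call site follows; the dataset and column names are illustrative, and the original project may wire the cast differently.

// Assuming an existing Dataset<Row> named dataset:
DataType target = mapDataType(Arrays.asList(dataset.schema().fields()), "amount", "double");
dataset = dataset.withColumn("amount", dataset.col("amount").cast(target));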
 
Example 5
Source File: VectorBinarizerBridgeTest.java    From spark-transformers with Apache License 2.0
@Test
public void testVectorBinarizerDense() {
    // prepare data

    JavaRDD<Row> jrdd = sc.parallelize(Arrays.asList(
            RowFactory.create(0d, 1d, new DenseVector(new double[]{-2d, -3d, -4d, -1d, 6d, -7d, 8d, 0d, 0d, 0d, 0d, 0d})),
            RowFactory.create(1d, 2d, new DenseVector(new double[]{4d, -5d, 6d, 7d, -8d, 9d, -10d, 0d, 0d, 0d, 0d, 0d})),
            RowFactory.create(2d, 3d, new DenseVector(new double[]{-5d, 6d, -8d, 9d, 10d, 11d, 12d, 0d, 0d, 0d, 0d, 0d}))
    ));

    StructType schema = new StructType(new StructField[]{
            new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("value1", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("vector1", new VectorUDT(), false, Metadata.empty())
    });

    DataFrame df = sqlContext.createDataFrame(jrdd, schema);
    VectorBinarizer vectorBinarizer = new VectorBinarizer()
            .setInputCol("vector1")
            .setOutputCol("binarized")
            .setThreshold(2d);


    //Export this model
    byte[] exportedModel = ModelExporter.export(vectorBinarizer, df);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
    //compare predictions
    Row[] sparkOutput = vectorBinarizer.transform(df).orderBy("id").select("id", "value1", "vector1", "binarized").collect();
    for (Row row : sparkOutput) {

        Map<String, Object> data = new HashMap<>();
        data.put(vectorBinarizer.getInputCol(), ((DenseVector) row.get(2)).toArray());
        transformer.transform(data);
        double[] output = (double[]) data.get(vectorBinarizer.getOutputCol());
        assertArrayEquals(output, ((DenseVector) row.get(3)).toArray(), 0d);
    }
}
 
Example 6
Source File: FirstPrediction.java    From net.jgp.labs.spark with Apache License 2.0
private void start() {
  SparkSession spark = SparkSession.builder().appName("First Prediction")
      .master("local").getOrCreate();

  StructType schema = new StructType(new StructField[] {
      new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
      new StructField("features", new VectorUDT(), false, Metadata.empty())
  });

  // TODO this example is not working yet
}
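The example above intentionally stops at the schema (the TODO is in the original source). A hedged sketch of how a DataFrame matching this (label, features) schema could be built, assuming the VectorUDT is org.apache.spark.ml.linalg.VectorUDT (so Vectors is org.apache.spark.ml.linalg.Vectors) and using illustrative values:

  List<Row> rows = Arrays.asList(
      RowFactory.create(1.0, Vectors.dense(1.0, 2.0, 3.0)),  // label, features
      RowFactory.create(0.0, Vectors.dense(4.0, 5.0, 6.0)));
  Dataset<Row> df = spark.createDataFrame(rows, schema);
  df.show();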
 
Example 7
Source File: JavaChiSqSelectorExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession
    .builder()
    .appName("JavaChiSqSelectorExample")
    .getOrCreate();

  // $example on$
  List<Row> data = Arrays.asList(
    RowFactory.create(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),
    RowFactory.create(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),
    RowFactory.create(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)
  );
  StructType schema = new StructType(new StructField[]{
    new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("features", new VectorUDT(), false, Metadata.empty()),
    new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty())
  });

  Dataset<Row> df = spark.createDataFrame(data, schema);

  ChiSqSelector selector = new ChiSqSelector()
    .setNumTopFeatures(1)
    .setFeaturesCol("features")
    .setLabelCol("clicked")
    .setOutputCol("selectedFeatures");

  Dataset<Row> result = selector.fit(df).transform(df);

  System.out.println("ChiSqSelector output with top " + selector.getNumTopFeatures()
      + " features selected");
  result.show();

  // $example off$
  spark.stop();
}
 
Example 8
Source File: JavaQuantileDiscretizerExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession
    .builder()
    .appName("JavaQuantileDiscretizerExample")
    .getOrCreate();

  // $example on$
  List<Row> data = Arrays.asList(
    RowFactory.create(0, 18.0),
    RowFactory.create(1, 19.0),
    RowFactory.create(2, 8.0),
    RowFactory.create(3, 5.0),
    RowFactory.create(4, 2.2)
  );

  StructType schema = new StructType(new StructField[]{
    new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("hour", DataTypes.DoubleType, false, Metadata.empty())
  });

  Dataset<Row> df = spark.createDataFrame(data, schema);
  // $example off$
  // Output of QuantileDiscretizer for such small datasets can depend on the number of
  // partitions. Here we force a single partition to ensure consistent results.
  // Note this is not necessary for normal use cases
  df = df.repartition(1);
  // $example on$
  QuantileDiscretizer discretizer = new QuantileDiscretizer()
    .setInputCol("hour")
    .setOutputCol("result")
    .setNumBuckets(3);

  Dataset<Row> result = discretizer.fit(df).transform(df);
  result.show();
  // $example off$
  spark.stop();
}
 
Example 9
Source File: MinMaxScalerBridgeTest.java    From spark-transformers with Apache License 2.0
@Test
public void testMinMaxScaler() {
    //prepare data
    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
            RowFactory.create(1.0, Vectors.dense(data[0])),
            RowFactory.create(2.0, Vectors.dense(data[1])),
            RowFactory.create(3.0, Vectors.dense(data[2])),
            RowFactory.create(4.0, Vectors.dense(data[3]))
    ));

    StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty())
    });

    Dataset<Row> df = spark.createDataFrame(jrdd, schema);

    //train model in spark
    MinMaxScalerModel sparkModel = new MinMaxScaler()
            .setInputCol("features")
            .setOutputCol("scaled")
            .setMin(-5)
            .setMax(5)
            .fit(df);


    //Export model, import it back and get transformer
    byte[] exportedModel = ModelExporter.export(sparkModel);
    final Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //compare predictions
    List<Row> sparkOutput = sparkModel.transform(df).orderBy("label").select("features", "scaled").collectAsList();
    assertCorrectness(sparkOutput, expected, transformer);
}
 
Example 10
Source File: FrameRDDConverterUtils.java    From systemds with Apache License 2.0
/**
 * Converts a Frame schema into a DataFrame schema.
 * 
 * @param fschema frame schema
 * @param containsID true if contains ID column
 * @return Spark StructType of StructFields representing schema
 */
public static StructType convertFrameSchemaToDFSchema(ValueType[] fschema, boolean containsID)
{
	// generate the DataFrame schema from the frame schema
	List<StructField> fields = new ArrayList<>();
	
	// add id column type
	if( containsID )
		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, 
				DataTypes.DoubleType, true));
	
	// add remaining types
	int col = 1;
	for (ValueType schema : fschema) {
		DataType dt = null;
		switch(schema) {
			case STRING:  dt = DataTypes.StringType;  break;
			case FP64:    dt = DataTypes.DoubleType;  break;
			case INT64:   dt = DataTypes.LongType;    break;
			case BOOLEAN: dt = DataTypes.BooleanType; break;
			default:      dt = DataTypes.StringType;
				LOG.warn("Using default type String for " + schema.toString());
		}
		fields.add(DataTypes.createStructField("C"+col++, dt, true));
	}
	
	return DataTypes.createStructType(fields);
}
 
Example 11
Source File: JavaBucketizerExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession
    .builder()
    .appName("JavaBucketizerExample")
    .getOrCreate();

  // $example on$
  double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY};

  List<Row> data = Arrays.asList(
    RowFactory.create(-999.9),
    RowFactory.create(-0.5),
    RowFactory.create(-0.3),
    RowFactory.create(0.0),
    RowFactory.create(0.2),
    RowFactory.create(999.9)
  );
  StructType schema = new StructType(new StructField[]{
    new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
  });
  Dataset<Row> dataFrame = spark.createDataFrame(data, schema);

  Bucketizer bucketizer = new Bucketizer()
    .setInputCol("features")
    .setOutputCol("bucketedFeatures")
    .setSplits(splits);

  // Transform original data into its bucket index.
  Dataset<Row> bucketedData = bucketizer.transform(dataFrame);

  System.out.println("Bucketizer output with " + (bucketizer.getSplits().length-1) + " buckets");
  bucketedData.show();
  // $example off$

  spark.stop();
}
 
Example 12
Source File: DataFrames.java    From deeplearning4j with Apache License 2.0
/**
 * Convert the DataVec sequence schema to a StructType for Spark, for example for use in
 * {@link #toDataFrameSequence(Schema, JavaRDD)}.
 * <b>Note</b>: as per {@link #toDataFrameSequence(Schema, JavaRDD)}, the StructType has two additional columns added to it:<br>
 * - Column 0: Sequence UUID (name: {@link #SEQUENCE_UUID_COLUMN}) - a UUID for the original sequence<br>
 * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN}) - an index (integer, starting at 0) for the position
 * of this record in the original time series.<br>
 * These two columns are required if the data is to be converted back into a sequence at a later point, for example
 * using {@link #toRecordsSequence(Dataset)}.
 *
 * @param schema Schema to convert
 * @return StructType for the schema
 */
public static StructType fromSchemaSequence(Schema schema) {
    StructField[] structFields = new StructField[schema.numColumns() + 2];

    structFields[0] = new StructField(SEQUENCE_UUID_COLUMN, DataTypes.StringType, false, Metadata.empty());
    structFields[1] = new StructField(SEQUENCE_INDEX_COLUMN, DataTypes.IntegerType, false, Metadata.empty());

    for (int i = 0; i < schema.numColumns(); i++) {
        switch (schema.getColumnTypes().get(i)) {
            case Double:
                structFields[i + 2] =
                                new StructField(schema.getName(i), DataTypes.DoubleType, false, Metadata.empty());
                break;
            case Integer:
                structFields[i + 2] =
                                new StructField(schema.getName(i), DataTypes.IntegerType, false, Metadata.empty());
                break;
            case Long:
                structFields[i + 2] =
                                new StructField(schema.getName(i), DataTypes.LongType, false, Metadata.empty());
                break;
            case Float:
                structFields[i + 2] =
                                new StructField(schema.getName(i), DataTypes.FloatType, false, Metadata.empty());
                break;
            default:
                throw new IllegalStateException(
                                "This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
        }
    }
    return new StructType(structFields);
}
 
Example 13
Source File: VectorAssemblerBridgeTest.java    From spark-transformers with Apache License 2.0
@Test
public void testVectorAssembler() {
    // prepare data

    JavaRDD<Row> jrdd = sc.parallelize(Arrays.asList(
            RowFactory.create(0d, 1d, new DenseVector(new double[]{2d, 3d})),
            RowFactory.create(1d, 2d, new DenseVector(new double[]{3d, 4d})),
            RowFactory.create(2d, 3d, new DenseVector(new double[]{4d, 5d})),
            RowFactory.create(3d, 4d, new DenseVector(new double[]{5d, 6d})),
            RowFactory.create(4d, 5d, new DenseVector(new double[]{6d, 7d}))
    ));

    StructType schema = new StructType(new StructField[]{
            new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("value1", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("vector1", new VectorUDT(), false, Metadata.empty())
    });

    DataFrame df = sqlContext.createDataFrame(jrdd, schema);
    VectorAssembler vectorAssembler = new VectorAssembler()
            .setInputCols(new String[]{"value1", "vector1"})
            .setOutputCol("feature");


    //Export this model
    byte[] exportedModel = ModelExporter.export(vectorAssembler, null);

    String exportedModelJson = new String(exportedModel);
    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
    //compare predictions
    Row[] sparkOutput = vectorAssembler.transform(df).orderBy("id").select("id", "value1", "vector1", "feature").collect();
    for (Row row : sparkOutput) {

        Map<String, Object> data = new HashMap<>();
        data.put(vectorAssembler.getInputCols()[0], row.get(1));
        data.put(vectorAssembler.getInputCols()[1], ((DenseVector) row.get(2)).toArray());
        transformer.transform(data);
        double[] output = (double[]) data.get(vectorAssembler.getOutputCol());
        assertArrayEquals(output, ((DenseVector) row.get(3)).toArray(), 0d);
    }
}
 
Example 14
Source File: ChiSqSelectorBridgeTest.java    From spark-transformers with Apache License 2.0
@Test
public void testChiSqSelector() {
    // prepare data

    List<Row> inputData = Arrays.asList(
            RowFactory.create(0d, 0d, new DenseVector(new double[]{8d, 7d, 0d})),
            RowFactory.create(1d, 1d, new DenseVector(new double[]{0d, 9d, 6d})),
            RowFactory.create(2d, 1d, new DenseVector(new double[]{0.0d, 9.0d, 8.0d})),
            RowFactory.create(3d, 2d, new DenseVector(new double[]{8.0d, 9.0d, 5.0d}))
    );

    double[] preFilteredData = {0.0d, 6.0d, 8.0d, 5.0d};

    StructType schema = new StructType(new StructField[]{
            new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty())
    });

    Dataset<Row> df = spark.createDataFrame(inputData, schema);
    ChiSqSelector chiSqSelector = new ChiSqSelector();
    chiSqSelector.setNumTopFeatures(1);
    chiSqSelector.setFeaturesCol("features");
    chiSqSelector.setLabelCol("label");
    chiSqSelector.setOutputCol("output");

    ChiSqSelectorModel chiSqSelectorModel = chiSqSelector.fit(df);

    //Export this model
    byte[] exportedModel = ModelExporter.export(chiSqSelectorModel);

    String exportedModelJson = new String(exportedModel);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //compare predictions
    List<Row> sparkOutput = chiSqSelectorModel.transform(df).orderBy("id").select("id", "label", "features", "output").collectAsList();
    for (Row row : sparkOutput) {
        Map<String, Object> data = new HashMap<>();
        data.put(chiSqSelectorModel.getFeaturesCol(), ((DenseVector) row.get(2)).toArray());
        transformer.transform(data);
        double[] output = (double[]) data.get(chiSqSelectorModel.getOutputCol());
        System.out.println(Arrays.toString(output));
        assertArrayEquals(output, ((DenseVector) row.get(3)).toArray(), 0d);
    }
}
 
Example 15
Source File: BucketizerBridgeTest.java    From spark-transformers with Apache License 2.0
@Test
public void bucketizerTest() {
    double[] validData = {-0.5, -0.3, 0.0, 0.2};
    double[] expectedBuckets = {0.0, 0.0, 1.0, 1.0};
    double[] splits = {-0.5, 0.0, 0.5};

    StructType schema = new StructType(new StructField[]{
            new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
            new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
    });
    List<Row> trainingData = Arrays.asList(
            cr(0, validData[0]),
            cr(1, validData[1]),
            cr(2, validData[2]),
            cr(3, validData[3]));

    Dataset<Row> df = spark.createDataFrame(trainingData, schema);

    Bucketizer sparkModel = new Bucketizer()
            .setInputCol("feature")
            .setOutputCol("result")
            .setSplits(splits);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkModel);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    List<Row> sparkOutput = sparkModel.transform(df).orderBy("id").select("id", "feature", "result").collectAsList();

    for (Row r : sparkOutput) {
        double input = r.getDouble(1);
        double sparkOp = r.getDouble(2);

        Map<String, Object> data = new HashMap<String, Object>();
        data.put(sparkModel.getInputCol(), input);
        transformer.transform(data);
        double transformedInput = (double) data.get(sparkModel.getOutputCol());

        assertTrue((transformedInput >= 0) && (transformedInput <= 1));
        assertEquals(transformedInput, sparkOp, 0.01);
        assertEquals(transformedInput, expectedBuckets[r.getInt(0)], 0.01);
    }
}
 
Example 16
Source File: ConfigurationDataTypes.java    From envelope with Apache License 2.0
public static DataType getSparkDataType(String typeString) {
  DataType type;

  String prec_scale_regex_groups = "\\s*(decimal)\\s*\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)\\s*";
  Pattern prec_scale_regex_pattern = Pattern.compile(prec_scale_regex_groups);
  Matcher prec_scale_regex_matcher = prec_scale_regex_pattern.matcher(typeString);

  if (prec_scale_regex_matcher.matches()) {
    int precision = Integer.parseInt(prec_scale_regex_matcher.group(2)); 
    int scale = Integer.parseInt(prec_scale_regex_matcher.group(3)); 
    type = DataTypes.createDecimalType(precision, scale);
  }
  else {
    switch (typeString) {
      case DECIMAL:
        type = DataTypes.createDecimalType();
        break;
      case STRING:
        type = DataTypes.StringType;
        break;
      case FLOAT:
        type = DataTypes.FloatType;
        break;
      case DOUBLE:
        type = DataTypes.DoubleType;
        break;
      case BYTE:
        type = DataTypes.ByteType;
        break;
      case SHORT:
        type = DataTypes.ShortType;
        break;
      case INT:
        type = DataTypes.IntegerType;
        break;
      case LONG:
        type = DataTypes.LongType;
        break;
      case BOOLEAN:
        type = DataTypes.BooleanType;
        break;
      case BINARY:
        type = DataTypes.BinaryType;
        break;
      case DATE:
        type = DataTypes.DateType;
        break;
      case TIMESTAMP:
        type = DataTypes.TimestampType;
        break;
      default:
        throw new RuntimeException("Unsupported or unrecognized field type: " + typeString);
    } 
  }

  return type;
}
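Illustrative calls, assuming the string constants (DECIMAL, DOUBLE, and so on) hold the lowercase type names they are matched against:

DataType d1 = getSparkDataType("decimal(10, 2)"); // DecimalType(10,2) via the regex branch
DataType d2 = getSparkDataType("double");         // DataTypes.DoubleType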
 
Example 17
Source File: AvroUtils.java    From envelope with Apache License 2.0
/**
 * Convert an Avro Schema into its associated Spark DataType.
 *
 * @param schemaType Avro Schema
 * @return DataType representation
 */
public static DataType dataTypeFor(Schema schemaType) {
  LOG.trace("Converting Schema[{}] to DataType", schemaType);

  // Unwrap "optional" unions to the base type
  boolean isOptional = isNullable(schemaType);

  if (isOptional) {
    // if only 2 items in the union, then "unwrap," otherwise, it's a full union and should be rendered as such
    if (schemaType.getTypes().size() == 2) {
      LOG.trace("Unwrapping simple 'optional' union for {}", schemaType);
      for (Schema s : schemaType.getTypes()) {
        if (s.getType().equals(NULL)) {
          continue;
        }
        // Unwrap
        schemaType = s;
        break;
      }
    }
  }

  // Convert supported LogicalTypes
  if (null != schemaType.getLogicalType()) {
    LogicalType logicalType = schemaType.getLogicalType();
    switch (logicalType.getName()) {
      case "date" :
        return DataTypes.DateType;
      case "timestamp-millis" :
        return DataTypes.TimestampType;
      case "decimal" :
        LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType;
        return DataTypes.createDecimalType(decimal.getPrecision(), decimal.getScale());
      default:
        // Pass-thru
        LOG.warn("Unsupported LogicalType[{}], continuing with underlying base type", logicalType.getName());
    }
  }

  switch (schemaType.getType()) {
    case RECORD:
      // StructType
      List<StructField> structFieldList = Lists.newArrayListWithCapacity(schemaType.getFields().size());
      for (Field f : schemaType.getFields()) {
        structFieldList.add(DataTypes.createStructField(f.name(), dataTypeFor(f.schema()), isNullable(f.schema())));
      }
      return DataTypes.createStructType(structFieldList);
    case ARRAY:
      Schema elementType = schemaType.getElementType();
      return DataTypes.createArrayType(dataTypeFor(elementType), isNullable(elementType));
    case MAP:
      Schema valueType = schemaType.getValueType();
      return DataTypes.createMapType(DataTypes.StringType, dataTypeFor(valueType), isNullable(valueType));
    case UNION:
      // StructType of members
      List<StructField> unionFieldList = Lists.newArrayListWithCapacity(schemaType.getTypes().size());
      int m = 0;
      for (Schema u : schemaType.getTypes()) {
        unionFieldList.add(DataTypes.createStructField("member" + m++, dataTypeFor(u), isNullable(u)));
      }
      return DataTypes.createStructType(unionFieldList);
    case FIXED:
    case BYTES:
      return DataTypes.BinaryType;
    case ENUM:
    case STRING:
      return DataTypes.StringType;
    case INT:
      return DataTypes.IntegerType;
    case LONG:
      return DataTypes.LongType;
    case FLOAT:
      return DataTypes.FloatType;
    case DOUBLE:
      return DataTypes.DoubleType;
    case BOOLEAN:
      return DataTypes.BooleanType;
    case NULL:
      return DataTypes.NullType;
    default:
      throw new RuntimeException(String.format("Unrecognized or unsupported Avro Type conversion: %s", schemaType));
  }
}
 
Example 18
Source File: ChiSqSelectorBridgeTest.java    From spark-transformers with Apache License 2.0
@Test
public void testChiSqSelector() {
    // prepare data

    JavaRDD<Row> jrdd = sc.parallelize(Arrays.asList(
            RowFactory.create(0d, 0d, new DenseVector(new double[]{8d, 7d, 0d})),
            RowFactory.create(1d, 1d, new DenseVector(new double[]{0d, 9d, 6d})),
            RowFactory.create(2d, 1d, new DenseVector(new double[]{0.0d, 9.0d, 8.0d})),
            RowFactory.create(3d, 2d, new DenseVector(new double[]{8.0d, 9.0d, 5.0d}))
    ));

    double[] preFilteredData = {0.0d, 6.0d, 8.0d, 5.0d};

    StructType schema = new StructType(new StructField[]{
            new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty())
    });

    DataFrame df = sqlContext.createDataFrame(jrdd, schema);
    ChiSqSelector chiSqSelector = new ChiSqSelector();
    chiSqSelector.setNumTopFeatures(1);
    chiSqSelector.setFeaturesCol("features");
    chiSqSelector.setLabelCol("label");
    chiSqSelector.setOutputCol("output");

    ChiSqSelectorModel chiSqSelectorModel = chiSqSelector.fit(df);

    //Export this model
    byte[] exportedModel = ModelExporter.export(chiSqSelectorModel, null);

    String exportedModelJson = new String(exportedModel);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //compare predictions
    Row[] sparkOutput = chiSqSelectorModel.transform(df).orderBy("id").select("id", "label", "features", "output").collect();
    for (Row row : sparkOutput) {
        Map<String, Object> data = new HashMap<>();
        data.put(chiSqSelectorModel.getFeaturesCol(), ((DenseVector) row.get(2)).toArray());
        transformer.transform(data);
        double[] output = (double[]) data.get(chiSqSelectorModel.getOutputCol());
        System.out.println(Arrays.toString(output));
        assertArrayEquals(output, ((DenseVector) row.get(3)).toArray(), 0d);
    }
}
 
Example 19
Source File: OneHotEncoderBridgeTest.java    From spark-transformers with Apache License 2.0
@Test
public void testOneHotEncoding() {
    // prepare data
    JavaRDD<Row> jrdd = sc.parallelize(Arrays.asList(
            RowFactory.create(0d, "a"),
            RowFactory.create(1d, "b"),
            RowFactory.create(2d, "c"),
            RowFactory.create(3d, "a"),
            RowFactory.create(4d, "a"),
            RowFactory.create(5d, "c")
    ));

    StructType schema = new StructType(new StructField[]{
            new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("category", DataTypes.StringType, false, Metadata.empty())
    });

    DataFrame df = sqlContext.createDataFrame(jrdd, schema);
    StringIndexerModel indexer = new StringIndexer()
            .setInputCol("category")
            .setOutputCol("categoryIndex")
            .fit(df);
    DataFrame indexed = indexer.transform(df);

    OneHotEncoder sparkModel = new OneHotEncoder()
            .setInputCol("categoryIndex")
            .setOutputCol("categoryVec");

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkModel, indexed);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //compare predictions
    Row[] sparkOutput = sparkModel.transform(indexed).orderBy("id").select("id", "categoryIndex", "categoryVec").collect();
    for (Row row : sparkOutput) {

        Map<String, Object> data = new HashMap<String, Object>();
        data.put(sparkModel.getInputCol(), row.getDouble(1));
        transformer.transform(data);
        double[] transformedOp = (double[]) data.get(sparkModel.getOutputCol());

        double[] sparkOp = ((Vector) row.get(2)).toArray();
        assertArrayEquals(transformedOp, sparkOp, EPSILON);
    }
}
 
Example 20
Source File: JavaUserDefinedUntypedAggregation.java    From incubator-nemo with Apache License 2.0
/**
 * The data type of the returned value.
 *
 * @return double type.
 */
public DataType dataType() {
  return DataTypes.DoubleType;
}
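This dataType() method is one piece of an untyped UserDefinedAggregateFunction. For context, a hedged sketch of the full class shape, modeled on the standard Spark untyped-UDAF average example (the original incubator-nemo file may differ in its details):

import java.util.Collections;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MyAverage extends UserDefinedAggregateFunction {
  // Input: a single long column to average.
  @Override
  public StructType inputSchema() {
    return DataTypes.createStructType(Collections.singletonList(
        DataTypes.createStructField("inputColumn", DataTypes.LongType, true)));
  }

  // Aggregation buffer: running sum and count.
  @Override
  public StructType bufferSchema() {
    return new StructType()
        .add("sum", DataTypes.LongType)
        .add("count", DataTypes.LongType);
  }

  // The data type of the returned value (the method shown above).
  @Override
  public DataType dataType() {
    return DataTypes.DoubleType;
  }

  @Override
  public boolean deterministic() {
    return true;
  }

  @Override
  public void initialize(MutableAggregationBuffer buffer) {
    buffer.update(0, 0L); // sum
    buffer.update(1, 0L); // count
  }

  @Override
  public void update(MutableAggregationBuffer buffer, Row input) {
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getLong(0) + input.getLong(0));
      buffer.update(1, buffer.getLong(1) + 1);
    }
  }

  @Override
  public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
  }

  @Override
  public Double evaluate(Row buffer) {
    return ((double) buffer.getLong(0)) / buffer.getLong(1);
  }
}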