Java Code Examples for org.datavec.api.transform.schema.Schema#getMetaData()

The following examples show how to use org.datavec.api.transform.schema.Schema#getMetaData(). You can vote up the examples you like or vote down the ones you don't like, and you can go to the original project or source file by following the links above each example. You may also check out the related API usage on the sidebar.
Example 1
Source File: DeriveColumnsFromTimeTransform.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Validates the input schema and caches the state this transform needs: the index of the
 * column to insert after, the index of the source time column, and the source column's
 * time zone.
 *
 * @param inputSchema schema of the input data
 * @throws IllegalStateException if the insert-after column or source column is missing,
 *                               or if the source column is not a time column
 */
@Override
public void setInputSchema(Schema inputSchema) {
    insertAfterIdx = inputSchema.getColumnNames().indexOf(insertAfter);
    if (insertAfterIdx == -1) {
        throw new IllegalStateException(
                        "Invalid schema/insert after column: input schema does not contain column \"" + insertAfter
                                        + "\"");
    }

    deriveFromIdx = inputSchema.getColumnNames().indexOf(columnName);
    if (deriveFromIdx == -1) {
        throw new IllegalStateException(
                        "Invalid source column: input schema does not contain column \"" + columnName + "\"");
    }

    this.inputSchema = inputSchema;

    //Fetch the column metadata once rather than performing three separate lookups
    ColumnMetaData colMeta = inputSchema.getMetaData(columnName);
    if (!(colMeta instanceof TimeMetaData))
        throw new IllegalStateException("Invalid state: input column \"" + columnName
                        + "\" is not a time column. Is: " + colMeta);
    inputTimeZone = ((TimeMetaData) colMeta).getTimeZone();
}
 
Example 2
Source File: NDArrayColumnsMathOpTransform.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Derives the output column metadata for an element-wise NDArray math op over multiple
 * input columns. All input columns must be NDArray columns with identical shapes; the
 * derived column has that same shape.
 *
 * @param newColumnName name for the derived output column
 * @param inputSchema   schema containing the input columns
 * @return NDArray metadata for the derived column
 * @throws IllegalStateException         if no input columns were configured
 * @throws RuntimeException              if any input column is not an NDArray column
 * @throws UnsupportedOperationException if the input columns have differing shapes
 */
@Override
protected ColumnMetaData derivedColumnMetaData(String newColumnName, Schema inputSchema) {
    //Guard against a misconfigured transform: columns[0] below would otherwise throw
    //an uninformative ArrayIndexOutOfBoundsException
    if (columns == null || columns.length == 0) {
        throw new IllegalStateException("No input columns were configured for this transform");
    }

    //Check types
    for (int i = 0; i < columns.length; i++) {
        if (inputSchema.getMetaData(columns[i]).getColumnType() != ColumnType.NDArray) {
            throw new RuntimeException("Column " + columns[i] + " is not an NDArray column");
        }
    }

    //Check shapes: every column must match the first column's shape
    NDArrayMetaData meta = (NDArrayMetaData) inputSchema.getMetaData(columns[0]);
    for (int i = 1; i < columns.length; i++) {
        NDArrayMetaData meta2 = (NDArrayMetaData) inputSchema.getMetaData(columns[i]);
        if (!Arrays.equals(meta.getShape(), meta2.getShape())) {
            throw new UnsupportedOperationException(
                            "Cannot perform NDArray operation on columns with different shapes: " + "Columns \""
                                            + columns[0] + "\" and \"" + columns[i] + "\" have shapes: "
                                            + Arrays.toString(meta.getShape()) + " and "
                                            + Arrays.toString(meta2.getShape()));
        }
    }

    return new NDArrayMetaData(newColumnName, meta.getShape());
}
 
Example 3
Source File: CategoricalToOneHotTransform.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Configures this transform from the input schema: records the target column's index and
 * builds the state-name -> index mapping used for one-hot encoding.
 *
 * @param inputSchema schema of the input data; the target column must be categorical
 * @throws IllegalStateException if the target column is not a categorical column
 */
@Override
public void setInputSchema(Schema inputSchema) {
    super.setInputSchema(inputSchema);

    columnIdx = inputSchema.getIndexOfColumn(columnName);
    ColumnMetaData meta = inputSchema.getMetaData(columnName);
    if (!(meta instanceof CategoricalMetaData)) {
        throw new IllegalStateException("Cannot convert column \"" + columnName
                        + "\" from categorical to one-hot: column is not categorical (is: " + meta.getColumnType()
                        + ")");
    }
    this.stateNames = ((CategoricalMetaData) meta).getStateNames();

    //Map each state name to its position for O(1) one-hot index lookup
    this.statesMap = new HashMap<>(stateNames.size());
    int idx = 0;
    for (String state : stateNames) {
        this.statesMap.put(state, idx++);
    }
}
 
Example 4
Source File: TestTransforms.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * StringToCategoricalTransform should convert the String column to a Categorical column
 * with the supplied state names, passing values through unchanged.
 */
@Test
public void testStringToCategoricalTransform() {
    Schema schema = getSchema(ColumnType.String);

    List<String> states = Arrays.asList("zero", "one", "two");
    Transform transform = new StringToCategoricalTransform("column", states);
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //Output schema: a single categorical column carrying the expected state names
    assertEquals(1, out.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Categorical, out.getMetaData(0).getColumnType());
    CategoricalMetaData meta = (CategoricalMetaData) out.getMetaData(0);
    assertEquals(states, meta.getStateNames());

    //Values should map through unchanged
    for (String state : states) {
        assertEquals(Collections.singletonList((Writable) new Text(state)),
                transform.map(Collections.singletonList((Writable) new Text(state))));
    }
}
 
Example 5
Source File: TestTransforms.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * CategoricalToIntegerTransform should map each category to its index in the state list,
 * producing an Integer column restricted to the valid index range.
 */
@Test
public void testCategoricalToInteger() {
    Schema schema = getSchema(ColumnType.Categorical, "zero", "one", "two");

    Transform transform = new CategoricalToIntegerTransform("column");
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //Output column is an Integer column bounded by the valid state indices [0, 2]
    TestCase.assertEquals(ColumnType.Integer, out.getMetaData(0).getColumnType());
    IntegerMetaData meta = (IntegerMetaData) out.getMetaData(0);
    assertNotNull(meta.getMinAllowedValue());
    assertEquals(0, (int) meta.getMinAllowedValue());
    assertNotNull(meta.getMaxAllowedValue());
    assertEquals(2, (int) meta.getMaxAllowedValue());

    //Each state name maps to its position in the state list
    String[] states = {"zero", "one", "two"};
    for (int i = 0; i < states.length; i++) {
        assertEquals(i, transform.map(Collections.singletonList((Writable) new Text(states[i]))).get(0).toInt());
    }
}
 
Example 6
Source File: TestTransforms.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * CategoricalToOneHotTransform should expand the categorical column into one binary
 * Integer column per category, with a 1 in the position of the input's category.
 */
@Test
public void testCategoricalToOneHotTransform() {
    Schema schema = getSchema(ColumnType.Categorical, "zero", "one", "two");

    Transform transform = new CategoricalToOneHotTransform("column");
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //One output column per category, each restricted to {0, 1}
    assertEquals(3, out.getColumnMetaData().size());
    for (int i = 0; i < 3; i++) {
        TestCase.assertEquals(ColumnType.Integer, out.getMetaData(i).getColumnType());
        IntegerMetaData meta = (IntegerMetaData) out.getMetaData(i);
        assertNotNull(meta.getMinAllowedValue());
        assertEquals(0, (int) meta.getMinAllowedValue());
        assertNotNull(meta.getMaxAllowedValue());
        assertEquals(1, (int) meta.getMaxAllowedValue());
    }

    //Each category maps to a one-hot vector with the 1 at that category's index
    String[] states = {"zero", "one", "two"};
    for (int hot = 0; hot < states.length; hot++) {
        List<Writable> expected = new ArrayList<>();
        for (int i = 0; i < states.length; i++) {
            expected.add(new IntWritable(i == hot ? 1 : 0));
        }
        assertEquals(expected, transform.map(Collections.singletonList((Writable) new Text(states[hot]))));
    }
}
 
Example 7
Source File: TestTransforms.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * DoubleMathOpTransform with Multiply/5.0 should scale both the allowed value range and
 * every mapped value by 5.
 */
@Test
public void testDoubleMathOpTransform() {
    Schema schema = new Schema.Builder().addColumnDouble("column", -1.0, 1.0).build();

    Transform transform = new DoubleMathOpTransform("column", MathOp.Multiply, 5.0);
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //Allowed range is scaled by the multiplier: [-1, 1] -> [-5, 5]
    assertEquals(1, out.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Double, out.getType(0));
    DoubleMetaData meta = (DoubleMetaData) out.getMetaData(0);
    assertEquals(-5.0, meta.getMinAllowedValue(), 1e-6);
    assertEquals(5.0, meta.getMaxAllowedValue(), 1e-6);

    //Each value is multiplied by 5
    double[] inputs = {-1, 0, 1};
    for (double in : inputs) {
        assertEquals(Collections.singletonList((Writable) new DoubleWritable(in * 5.0)),
                transform.map(Collections.singletonList((Writable) new DoubleWritable(in))));
    }
}
 
Example 8
Source File: TestTransforms.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * StringToCategoricalTransform should convert the String column to a Categorical column
 * with the supplied state names, passing values through unchanged.
 */
@Test
public void testStringToCategoricalTransform() {
    Schema schema = getSchema(ColumnType.String);

    List<String> states = Arrays.asList("zero", "one", "two");
    Transform transform = new StringToCategoricalTransform("column", states);
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //Output schema: a single categorical column carrying the expected state names
    assertEquals(1, out.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Categorical, out.getMetaData(0).getColumnType());
    CategoricalMetaData meta = (CategoricalMetaData) out.getMetaData(0);
    assertEquals(states, meta.getStateNames());

    //Values should map through unchanged
    for (String state : states) {
        assertEquals(Collections.singletonList((Writable) new Text(state)),
                transform.map(Collections.singletonList((Writable) new Text(state))));
    }
}
 
Example 9
Source File: DeriveColumnsFromTimeTransform.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Validates the input schema and caches the state this transform needs: the index of the
 * column to insert after, the index of the source time column, and the source column's
 * time zone.
 *
 * @param inputSchema schema of the input data
 * @throws IllegalStateException if the insert-after column or source column is missing,
 *                               or if the source column is not a time column
 */
@Override
public void setInputSchema(Schema inputSchema) {
    insertAfterIdx = inputSchema.getColumnNames().indexOf(insertAfter);
    if (insertAfterIdx == -1) {
        throw new IllegalStateException(
                        "Invalid schema/insert after column: input schema does not contain column \"" + insertAfter
                                        + "\"");
    }

    deriveFromIdx = inputSchema.getColumnNames().indexOf(columnName);
    if (deriveFromIdx == -1) {
        throw new IllegalStateException(
                        "Invalid source column: input schema does not contain column \"" + columnName + "\"");
    }

    this.inputSchema = inputSchema;

    //Fetch the column metadata once rather than performing three separate lookups
    ColumnMetaData colMeta = inputSchema.getMetaData(columnName);
    if (!(colMeta instanceof TimeMetaData))
        throw new IllegalStateException("Invalid state: input column \"" + columnName
                        + "\" is not a time column. Is: " + colMeta);
    inputTimeZone = ((TimeMetaData) colMeta).getTimeZone();
}
 
Example 10
Source File: TestTransforms.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * LongMathOpTransform with Multiply/5 should scale both the allowed value range and
 * every mapped value by 5.
 */
@Test
public void testLongMathOpTransform() {
    Schema schema = new Schema.Builder().addColumnLong("column", -1L, 1L).build();

    Transform transform = new LongMathOpTransform("column", MathOp.Multiply, 5);
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //Allowed range is scaled by the multiplier: [-1, 1] -> [-5, 5]
    assertEquals(1, out.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Long, out.getType(0));
    LongMetaData meta = (LongMetaData) out.getMetaData(0);
    assertEquals(-5, (long) meta.getMinAllowedValue());
    assertEquals(5, (long) meta.getMaxAllowedValue());

    //Each value is multiplied by 5
    long[] inputs = {-1, 0, 1};
    for (long in : inputs) {
        assertEquals(Collections.singletonList((Writable) new LongWritable(in * 5)),
                transform.map(Collections.singletonList((Writable) new LongWritable(in))));
    }
}
 
Example 11
Source File: TestTransforms.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * SubtractMeanNormalizer should shift each value by -mu and leave the output column
 * unbounded (no min/max restriction).
 */
@Test
public void testSubtractMeanNormalizer() {
    Schema schema = getSchema(ColumnType.Double);

    double mu = 1.0;
    Transform transform = new SubtractMeanNormalizer("column", mu);
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //Output remains a Double column with no value restrictions
    assertEquals(1, out.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Double, out.getMetaData(0).getColumnType());
    DoubleMetaData meta = (DoubleMetaData) out.getMetaData(0);
    assertNull(meta.getMinAllowedValue());
    assertNull(meta.getMaxAllowedValue());

    //map(x) == x - mu
    double[] inputs = {mu, 10};
    for (double in : inputs) {
        assertEquals(in - mu,
                transform.map(Collections.singletonList((Writable) new DoubleWritable(in))).get(0).toDouble(),
                1e-6);
    }
}
 
Example 12
Source File: TestTransforms.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Log2Normalizer should map x to scale * log2((x - min) / (mu - min) + 1), so the
 * configured minimum maps to 0 and the output column has a lower bound of 0.
 */
@Test
public void testLog2Normalizer() {
    Schema schema = getSchema(ColumnType.Double);

    double mu = 2.0;
    double min = 1.0;
    double scale = 0.5;

    Transform transform = new Log2Normalizer("column", mu, min, scale);
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //Output is a Double column with lower bound 0 and no upper bound
    assertEquals(1, out.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Double, out.getMetaData(0).getColumnType());
    DoubleMetaData meta = (DoubleMetaData) out.getMetaData(0);
    assertNotNull(meta.getMinAllowedValue());
    assertEquals(0, meta.getMinAllowedValue(), 1e-6);
    assertNull(meta.getMaxAllowedValue());

    //Every input (including x == min, which maps to exactly 0) follows the same formula
    double loge2 = Math.log(2);
    double[] inputs = {min, 10, 3};
    for (double in : inputs) {
        double expected = scale * Math.log((in - min) / (mu - min) + 1) / loge2;
        assertEquals(expected,
                transform.map(Collections.singletonList((Writable) new DoubleWritable(in))).get(0).toDouble(),
                1e-6);
    }
}
 
Example 13
Source File: TestTransforms.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * TextToCharacterIndexTransform should expand each text row into one output step per
 * character (replaced by its vocabulary index), carrying other columns through unchanged.
 */
@Test
public void testTextToCharacterIndexTransform(){

    Schema s = new Schema.Builder().addColumnString("col").addColumnDouble("d").build();

    List<List<Writable>> inSeq = Arrays.asList(
            Arrays.<Writable>asList(new Text("text"), new DoubleWritable(1.0)),
            Arrays.<Writable>asList(new Text("ab"), new DoubleWritable(2.0)));

    //Character -> index vocabulary: a=0, b=1, e=2, t=3, x=4
    Map<Character,Integer> map = new HashMap<>();
    String vocabChars = "abetx";
    for (int i = 0; i < vocabChars.length(); i++) {
        map.put(vocabChars.charAt(i), i);
    }

    //"text" -> t,e,x,t -> 3,2,4,3 (with d=1.0); "ab" -> 0,1 (with d=2.0)
    List<List<Writable>> exp = new ArrayList<>();
    for (int idx : new int[]{3, 2, 4, 3}) {
        exp.add(Arrays.<Writable>asList(new IntWritable(idx), new DoubleWritable(1.0)));
    }
    for (int idx : new int[]{0, 1}) {
        exp.add(Arrays.<Writable>asList(new IntWritable(idx), new DoubleWritable(2.0)));
    }

    Transform t = new TextToCharacterIndexTransform("col", "newName", map, false);
    t.setInputSchema(s);

    Schema outputSchema = t.transform(s);
    assertEquals(2, outputSchema.getColumnNames().size());
    assertEquals(ColumnType.Integer, outputSchema.getType(0));
    assertEquals(ColumnType.Double, outputSchema.getType(1));

    //Integer column range covers all vocabulary indices
    IntegerMetaData intMetadata = (IntegerMetaData)outputSchema.getMetaData(0);
    assertEquals(0, (int)intMetadata.getMinAllowedValue());
    assertEquals(4, (int)intMetadata.getMaxAllowedValue());

    assertEquals(exp, t.mapSequence(inSeq));
}
 
Example 14
Source File: TestTransforms.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Log2Normalizer should map x to scale * log2((x - min) / (mu - min) + 1), so the
 * configured minimum maps to 0 and the output column has a lower bound of 0.
 */
@Test
public void testLog2Normalizer() {
    Schema schema = getSchema(ColumnType.Double);

    double mu = 2.0;
    double min = 1.0;
    double scale = 0.5;

    Transform transform = new Log2Normalizer("column", mu, min, scale);
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //Output is a Double column with lower bound 0 and no upper bound
    assertEquals(1, out.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Double, out.getMetaData(0).getColumnType());
    DoubleMetaData meta = (DoubleMetaData) out.getMetaData(0);
    assertNotNull(meta.getMinAllowedValue());
    assertEquals(0, meta.getMinAllowedValue(), 1e-6);
    assertNull(meta.getMaxAllowedValue());

    //Every input (including x == min, which maps to exactly 0) follows the same formula
    double loge2 = Math.log(2);
    double[] inputs = {min, 10, 3};
    for (double in : inputs) {
        double expected = scale * Math.log((in - min) / (mu - min) + 1) / loge2;
        assertEquals(expected,
                transform.map(Collections.singletonList((Writable) new DoubleWritable(in))).get(0).toDouble(),
                1e-6);
    }
}
 
Example 15
Source File: TestTransforms.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * StringListToCategoricalSetTransform should expand a comma-separated string into one
 * true/false categorical column per category, "true" iff that category is present.
 */
@Test
public void testStringListToCategoricalSetTransform() {
    //Idea: String list to a set of categories... "a,c" for categories {a,b,c} -> "true","false","true"

    Schema schema = getSchema(ColumnType.String);

    List<String> categories = Arrays.asList("a", "b", "c");
    Transform transform = new StringListToCategoricalSetTransform("column", Arrays.asList("a", "b", "c"),
            Arrays.asList("a", "b", "c"), ",");
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //One true/false categorical output column per category
    assertEquals(3, out.getColumnMetaData().size());
    for (int i = 0; i < 3; i++) {
        TestCase.assertEquals(ColumnType.Categorical, out.getType(i));
        CategoricalMetaData meta = (CategoricalMetaData) out.getMetaData(i);
        assertEquals(Arrays.asList("true", "false"), meta.getStateNames());
    }

    //Derive the expected true/false vector from each comma-separated input
    String[] inputs = {"", "a", "b", "c", "a,c", "a,b,c"};
    for (String input : inputs) {
        Set<String> present = new HashSet<>(Arrays.asList(input.split(",")));
        List<Writable> expected = new ArrayList<>();
        for (String category : categories) {
            expected.add(new Text(present.contains(category) ? "true" : "false"));
        }
        assertEquals(expected, transform.map(Collections.singletonList((Writable) new Text(input))));
    }
}
 
Example 16
Source File: TestTransforms.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * MinMaxNormalizer should map the input range [0, 100] linearly onto the default output
 * range [0, 1], or onto a custom output range (here [-1, 1]) when one is supplied.
 */
@Test
public void testDoubleMinMaxNormalizerTransform() {
    Schema schema = getSchema(ColumnType.Double);

    //Default output range [0, 1] and a custom output range [-1, 1]
    Transform transform = new MinMaxNormalizer("column", 0, 100);
    Transform transform2 = new MinMaxNormalizer("column", 0, 100, -1, 1);
    transform.setInputSchema(schema);
    transform2.setInputSchema(schema);

    Schema out = transform.transform(schema);
    Schema out2 = transform2.transform(schema);

    //Metadata bounds reflect each normalizer's output range
    assertEquals(1, out.getColumnMetaData().size());
    TestCase.assertEquals(ColumnType.Double, out.getMetaData(0).getColumnType());
    DoubleMetaData meta = (DoubleMetaData) out.getMetaData(0);
    DoubleMetaData meta2 = (DoubleMetaData) out2.getMetaData(0);
    assertEquals(0, meta.getMinAllowedValue(), 1e-6);
    assertEquals(1, meta.getMaxAllowedValue(), 1e-6);
    assertEquals(-1, meta2.getMinAllowedValue(), 1e-6);
    assertEquals(1, meta2.getMaxAllowedValue(), 1e-6);

    //Inputs 0 / 100 / 50 map linearly into each output range
    double[] inputs = {0, 100, 50};
    double[] expected = {0.0, 1.0, 0.5};
    double[] expected2 = {-1.0, 1.0, 0.0};
    for (int i = 0; i < inputs.length; i++) {
        List<Writable> in = Collections.singletonList((Writable) new DoubleWritable(inputs[i]));
        assertEquals(expected[i], transform.map(in).get(0).toDouble(), 1e-6);
        assertEquals(expected2[i], transform2.map(in).get(0).toDouble(), 1e-6);
    }
}
 
Example 17
Source File: IntegerToOneHotTransform.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Validates the input schema and records the index of the target column.
 *
 * @param inputSchema schema of the input data; the target column must be an integer column
 * @throws IllegalStateException if the target column is not an integer column
 */
@Override
public void setInputSchema(Schema inputSchema) {
    super.setInputSchema(inputSchema);

    columnIdx = inputSchema.getIndexOfColumn(columnName);
    ColumnMetaData meta = inputSchema.getMetaData(columnName);
    if (meta instanceof IntegerMetaData) {
        return;     //Valid integer column: nothing further to configure
    }
    throw new IllegalStateException("Cannot convert column \"" + columnName
                    + "\" from integer to one-hot: column is not integer (is: " + meta.getColumnType() + ")");
}
 
Example 18
Source File: TestTransforms.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * StringListToCategoricalSetTransform should expand a comma-separated string into one
 * true/false categorical column per category, "true" iff that category is present.
 */
@Test
public void testStringListToCategoricalSetTransform() {
    //Idea: String list to a set of categories... "a,c" for categories {a,b,c} -> "true","false","true"

    Schema schema = getSchema(ColumnType.String);

    List<String> categories = Arrays.asList("a", "b", "c");
    Transform transform = new StringListToCategoricalSetTransform("column", Arrays.asList("a", "b", "c"),
            Arrays.asList("a", "b", "c"), ",");
    transform.setInputSchema(schema);
    Schema out = transform.transform(schema);

    //One true/false categorical output column per category
    assertEquals(3, out.getColumnMetaData().size());
    for (int i = 0; i < 3; i++) {
        TestCase.assertEquals(ColumnType.Categorical, out.getType(i));
        CategoricalMetaData meta = (CategoricalMetaData) out.getMetaData(i);
        assertEquals(Arrays.asList("true", "false"), meta.getStateNames());
    }

    //Derive the expected true/false vector from each comma-separated input
    String[] inputs = {"", "a", "b", "c", "a,c", "a,b,c"};
    for (String input : inputs) {
        Set<String> present = new HashSet<>(Arrays.asList(input.split(",")));
        List<Writable> expected = new ArrayList<>();
        for (String category : categories) {
            expected.add(new Text(present.contains(category) ? "true" : "false"));
        }
        assertEquals(expected, transform.map(Collections.singletonList((Writable) new Text(input))));
    }
}
 
Example 19
Source File: PivotTransform.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Builds the output schema for the pivot: expands the categorical key column into one
 * column per key state (each named "keyColumn[state]" and carrying a copy of the value
 * column's metadata), drops the value column, and passes every other column through.
 * As a side effect, infers {@code defaultValue} from the value column's type if it has
 * not been set explicitly.
 *
 * @param inputSchema input schema; must contain both the key column and the value column
 * @return the pivoted output schema
 * @throws UnsupportedOperationException if the key or value column is missing, or if no
 *                                       default value can be inferred for the value column's type
 */
@Override
public Schema transform(Schema inputSchema) {
    if (!inputSchema.hasColumn(keyColumn) || !inputSchema.hasColumn(valueColumn)) {
        throw new UnsupportedOperationException("Key or value column not found: " + keyColumn + ", " + valueColumn
                        + " in " + inputSchema.getColumnNames());
    }

    List<String> origNames = inputSchema.getColumnNames();
    List<ColumnMetaData> origMeta = inputSchema.getColumnMetaData();

    int i = 0;
    Iterator<String> namesIter = origNames.iterator();
    Iterator<ColumnMetaData> typesIter = origMeta.iterator();

    List<ColumnMetaData> newMeta = new ArrayList<>(inputSchema.numColumns());

    int idxKey = inputSchema.getIndexOfColumn(keyColumn);
    int idxValue = inputSchema.getIndexOfColumn(valueColumn);

    //The expanded key columns take their type/restrictions from the value column's metadata
    ColumnMetaData valueMeta = inputSchema.getMetaData(idxValue);

    while (namesIter.hasNext()) {
        String s = namesIter.next();
        ColumnMetaData t = typesIter.next();

        if (i == idxKey) {
            //Convert this to a set of separate columns
            List<String> stateNames = ((CategoricalMetaData) inputSchema.getMetaData(idxKey)).getStateNames();
            for (String stateName : stateNames) {
                String newName = s + "[" + stateName + "]";

                //Clone so each expanded column gets its own independently-named metadata
                ColumnMetaData newValueMeta = valueMeta.clone();
                newValueMeta.setName(newName);

                newMeta.add(newValueMeta);
            }
        } else if (i == idxValue) {
            i++;
            continue; //Skip column
        } else {
            //Pass-through column: metadata kept unchanged
            newMeta.add(t);
        }
        i++;
    }

    //Infer the default value if necessary
    if (defaultValue == null) {
        switch (valueMeta.getColumnType()) {
            case String:
                defaultValue = new Text("");
                break;
            case Integer:
                defaultValue = new IntWritable(0);
                break;
            case Long:
                defaultValue = new LongWritable(0);
                break;
            case Double:
                defaultValue = new DoubleWritable(0.0);
                break;
            case Float:
                defaultValue = new FloatWritable(0.0f);
                break;
            case Categorical:
                //No meaningful "zero" category exists; fall back to the null writable
                defaultValue = new NullWritable();
                break;
            case Time:
                defaultValue = new LongWritable(0);
                break;
            case Bytes:
                throw new UnsupportedOperationException("Cannot infer default value for bytes");
            case Boolean:
                defaultValue = new Text("false");
                break;
            default:
                throw new UnsupportedOperationException(
                                "Cannot infer default value for " + valueMeta.getColumnType());
        }
    }

    return inputSchema.newSchema(newMeta);
}
 
Example 20
Source File: TestTransforms.java    From DataVec with Apache License 2.0 4 votes vote down vote up
/**
 * TextToTermIndexSequenceTransform should expand each text row into one output step per
 * known vocabulary term (unknown terms are dropped), replacing the text column with an
 * integer index column and carrying the other columns through unchanged. Also verifies
 * the transform survives a JSON round-trip inside a TransformProcess.
 */
@Test
public void testTextToTermIndexSequenceTransform(){

    Schema schema = new Schema.Builder()
            .addColumnString("ID")
            .addColumnString("TEXT")
            .addColumnDouble("FEATURE")
            .build();
    List<String> vocab = Arrays.asList("zero", "one", "two", "three");
    List<List<Writable>> inSeq = Arrays.asList(
            Arrays.<Writable>asList(new Text("a"), new Text("zero four two"), new DoubleWritable(4.2)),
            Arrays.<Writable>asList(new Text("b"), new Text("six one two four three five"), new DoubleWritable(87.9)));

    Schema expSchema = new Schema.Builder()
            .addColumnString("ID")
            .addColumnInteger("INDEXSEQ", 0, 3)
            .addColumnDouble("FEATURE")
            .build();

    //Known terms only: "zero four two" -> [0, 2]; "six one two four three five" -> [1, 2, 3]
    List<List<Writable>> exp = new ArrayList<>();
    for (int idx : new int[]{0, 2}) {
        exp.add(Arrays.<Writable>asList(new Text("a"), new IntWritable(idx), new DoubleWritable(4.2)));
    }
    for (int idx : new int[]{1, 2, 3}) {
        exp.add(Arrays.<Writable>asList(new Text("b"), new IntWritable(idx), new DoubleWritable(87.9)));
    }

    Transform t = new TextToTermIndexSequenceTransform("TEXT", "INDEXSEQ", vocab, " ", false);
    t.setInputSchema(schema);

    Schema outputSchema = t.transform(schema);
    assertEquals(expSchema.getColumnNames(), outputSchema.getColumnNames());
    assertEquals(expSchema.getColumnTypes(), outputSchema.getColumnTypes());
    assertEquals(expSchema, outputSchema);

    assertEquals(3, outputSchema.getColumnNames().size());
    assertEquals(ColumnType.String, outputSchema.getType(0));
    assertEquals(ColumnType.Integer, outputSchema.getType(1));
    assertEquals(ColumnType.Double, outputSchema.getType(2));

    //Integer column range covers all vocabulary indices
    IntegerMetaData intMetadata = (IntegerMetaData)outputSchema.getMetaData(1);
    assertEquals(0, (int)intMetadata.getMinAllowedValue());
    assertEquals(3, (int)intMetadata.getMaxAllowedValue());

    assertEquals(exp, t.mapSequence(inSeq));

    //The transform must survive serialization to JSON and back
    TransformProcess tp = new TransformProcess.Builder(schema).transform(t).build();
    String json = tp.toJson();
    assertEquals(tp, TransformProcess.fromJson(json));
}