org.datavec.api.writable.DoubleWritable Java Examples

The following examples show how to use org.datavec.api.writable.DoubleWritable. Each example is drawn from an open-source project; the source file and license are noted in the header above it.
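
Before the project examples, here is a minimal standalone sketch of the class itself (a sketch, assuming only the one-argument constructor, get(), toDouble(), and the value-based equals() that the tests below rely on):

import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;

public class DoubleWritableBasics {
    public static void main(String[] args) {
        // Wrap a primitive double for use in DataVec records
        DoubleWritable dw = new DoubleWritable(3.14);

        System.out.println(dw.get());       // 3.14, the raw double
        System.out.println(dw.toDouble());  // 3.14, via the generic Writable conversion
        System.out.println(dw);             // "3.14", the string form used by text-based writers

        // Value-based equality, as relied on by assertEquals in the tests below
        Writable other = new DoubleWritable(3.14);
        System.out.println(dw.equals(other));  // true
    }
}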
Example #1
Source File: ExcelRecordReader.java    From DataVec with Apache License 2.0
private List<Writable> rowToRecord(Row currRow) {
    if(numColumns < 0) {
        numColumns = currRow.getLastCellNum();
    }

    if(currRow.getLastCellNum() != numColumns) {
        throw new IllegalStateException("Invalid number of columns for row. First number of columns found was " + numColumns + " but row " + currRow.getRowNum() + " was " + currRow.getLastCellNum());
    }

    List<Writable> ret = new ArrayList<>(currRow.getLastCellNum());
    for(Cell cell: currRow) {
        String cellValue = dataFormatter.formatCellValue(cell);
        switch(cell.getCellTypeEnum()) {
            case BLANK: ret.add(new Text("")); break;
            case STRING: ret.add(new Text(cellValue)); break;
            case BOOLEAN: ret.add(new BooleanWritable(Boolean.valueOf(cellValue))); break;
            case NUMERIC: ret.add(new DoubleWritable(Double.parseDouble(cellValue))); break;
            default: ret.add(new Text(cellValue));
        }
    }

    return ret;

}
 
Example #2
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testRecordReaderDataSetIteratorConcat() {

    //[DoubleWritable, NDArrayWritable(length 3), DoubleWritable, NDArrayWritable(length 3), IntWritable] -> concatenated automatically into a [1,9] feature vector; the final IntWritable is the label.

    List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1),
                    new NDArrayWritable(Nd4j.create(new double[] {2, 3, 4})), new DoubleWritable(5),
                    new NDArrayWritable(Nd4j.create(new double[] {6, 7, 8})), new IntWritable(9),
                    new IntWritable(1));

    RecordReader rr = new CollectionRecordReader(Collections.singletonList(l));

    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 5, 3);
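    //Arguments above: record reader, batch size 1, label at writable index 5, 3 label classes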

    DataSet ds = iter.next();
    INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,9});
    INDArray expL = Nd4j.create(new float[] {0, 1, 0}, new int[]{1,3});

    assertEquals(expF, ds.getFeatures());
    assertEquals(expL, ds.getLabels());
}
 
Example #3
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0
@Test(expected = NumberFormatException.class)
public void nonIntegerMultilabel() throws Exception {
    List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
                                            new IntWritable(2),
                                            new DoubleWritable(1.2));
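    //With MULTILABEL set, column 2 (outside the feature range 0-1) is a label column; the non-integer 1.2 triggers the expected NumberFormatException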
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
        configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }
}
 
Example #4
Source File: SVMLightRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testNextRecord() throws IOException, InterruptedException {
    SVMLightRecordReader rr = new SVMLightRecordReader();
    Configuration config = new Configuration();
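    //Indices in the file follow the one-based SVMLight convention; expose 10 feature columns, with no appended label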
    config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
    config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
    config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
    rr.initialize(config, new FileSplit(new ClassPathResource("svmlight/basic.txt").getFile()));

    Record record = rr.nextRecord();
    List<Writable> recordList = record.getRecord();
    assertEquals(new DoubleWritable(1.0), recordList.get(1));
    assertEquals(new DoubleWritable(3.0), recordList.get(5));
    assertEquals(new DoubleWritable(4.0), recordList.get(7));

    record = rr.nextRecord();
    recordList = record.getRecord();
    assertEquals(new DoubleWritable(0.1), recordList.get(0));
    assertEquals(new DoubleWritable(6.6), recordList.get(5));
    assertEquals(new DoubleWritable(80.0), recordList.get(7));
}
 
Example #5
Source File: TestGeoTransforms.java    From DataVec with Apache License 2.0
@Test
public void testCoordinatesDistanceTransform() throws Exception {
    Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
                    .build();

    Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
    transform.setInputSchema(schema);

    Schema out = transform.transform(schema);
    assertEquals(4, out.numColumns());
    assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
    assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
                    out.getColumnTypes());

    assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
                    transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
    assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
                    new DoubleWritable(Math.sqrt(160))),
                    transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
                                    new Text("10|5"))));
}
 
Example #6
Source File: CSVSparkTransformTest.java    From DataVec with Apache License 2.0
@Test
public void testTransformerBatch() throws Exception {
    List<Writable> input = new ArrayList<>();
    input.add(new DoubleWritable(1.0));
    input.add(new DoubleWritable(2.0));

    Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    List<Writable> output = new ArrayList<>();
    output.add(new Text("1.0"));
    output.add(new Text("2.0"));

    TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
    CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
    String[] values = new String[] {"1.0", "2.0"};
    SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
    BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
    for (int i = 0; i < 3; i++)
        batchCSVRecord.add(record);
    //data type is string, unable to convert
    BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
  /*  Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1);
    INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
    assertTrue(fromBase64.isMatrix());
    System.out.println("Base 64ed array " + fromBase64); */
}
 
Example #7
Source File: CSVSparkTransformTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testTransformer() throws Exception {
    List<Writable> input = new ArrayList<>();
    input.add(new DoubleWritable(1.0));
    input.add(new DoubleWritable(2.0));

    Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    List<Writable> output = new ArrayList<>();
    output.add(new Text("1.0"));
    output.add(new Text("2.0"));

    TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
    CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
    String[] values = new String[] {"1.0", "2.0"};
    SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
    Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
    INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
    assertTrue(fromBase64.isVector());
//    System.out.println("Base 64ed array " + fromBase64);
}
 
Example #8
Source File: TestRecordReaders.java    From deeplearning4j with Apache License 2.0
@Test
public void testClassIndexOutsideOfRangeRRMDSI() {

    Collection<Collection<Collection<Writable>>> c = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(1)));
    c.add(seq1);

    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    c.add(seq2);

    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c);
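    //(reader, mini-batch size 2, 2 possible label classes, label at column index 1); seq2's label value 2 cannot be one-hot encoded into 2 classes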
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, 2, 2, 1);

    try {
        DataSet ds = dsi.next();
        fail("Expected exception");
    } catch (Exception e) {
        assertTrue(e.getMessage(), e.getMessage().contains("to one-hot"));
    }
}
 
Example #9
Source File: TestTransformProcess.java    From DataVec with Apache License 2.0
@Test
public void testExecution(){

    Schema schema = new Schema.Builder()
            .addColumnsString("col")
            .addColumnsDouble("col2")
            .build();

    TransformProcess transformProcess = new TransformProcess.Builder(schema)
            .doubleMathOp("col2", MathOp.Add, 1.0)
            .build();

    List<Writable> in = Arrays.<Writable>asList(new Text("Text"), new DoubleWritable(2.0));
    List<Writable> exp = Arrays.<Writable>asList(new Text("Text"), new DoubleWritable(3.0));

    List<Writable> out = transformProcess.execute(in);
    assertEquals(exp, out);
}
 
Example #10
Source File: TestFilters.java    From DataVec with Apache License 2.0
@Test
public void testFilterNumColumns() {
    List<List<Writable>> list = new ArrayList<>();
    list.add(Collections.singletonList((Writable) new IntWritable(-1)));
    list.add(Collections.singletonList((Writable) new IntWritable(0)));
    list.add(Collections.singletonList((Writable) new IntWritable(2)));

    Schema schema = new Schema.Builder().addColumnInteger("intCol", 0, 10) //Only values in the range 0 to 10 are ok
                    .addColumnDouble("doubleCol", -100.0, 100.0) //-100 to 100 only; no NaN or infinite
                    .build();
    Filter numColumns = new InvalidNumColumns(schema);
    for (int i = 0; i < list.size(); i++)
        assertTrue(numColumns.removeExample(list.get(i)));

    List<Writable> correct = Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(2));
    assertFalse(numColumns.removeExample(correct));

}
 
Example #11
Source File: NDArrayHistogramCounter.java    From DataVec with Apache License 2.0
@Override
public HistogramCounter add(Writable w) {
    INDArray arr = ((NDArrayWritable) w).get();
    if (arr == null) {
        return this;
    }

    long length = arr.length();
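    //Reuse a single DoubleWritable, calling set() per element, rather than allocating one per array entry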
    DoubleWritable dw = new DoubleWritable();
    for (int i = 0; i < length; i++) {
        dw.set(arr.getDouble(i));
        underlying.add(dw);
    }

    return this;
}
 
Example #12
Source File: CSVSparkTransformTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testTransformerBatch() throws Exception {
    List<Writable> input = new ArrayList<>();
    input.add(new DoubleWritable(1.0));
    input.add(new DoubleWritable(2.0));

    Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    List<Writable> output = new ArrayList<>();
    output.add(new Text("1.0"));
    output.add(new Text("2.0"));

    TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
    CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
    String[] values = new String[] {"1.0", "2.0"};
    SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
    BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
    for (int i = 0; i < 3; i++)
        batchCSVRecord.add(record);
    //data type is string, unable to convert
    BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
  /*  Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1);
    INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
    assertTrue(fromBase64.isMatrix());
    System.out.println("Base 64ed array " + fromBase64); */
}
 
Example #13
Source File: TestTransformProcess.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testExecution(){

    Schema schema = new Schema.Builder()
            .addColumnsString("col")
            .addColumnsDouble("col2")
            .build();

    TransformProcess transformProcess = new TransformProcess.Builder(schema)
            .doubleMathOp("col2", MathOp.Add, 1.0)
            .build();

    List<Writable> in = Arrays.<Writable>asList(new Text("Text"), new DoubleWritable(2.0));
    List<Writable> exp = Arrays.<Writable>asList(new Text("Text"), new DoubleWritable(3.0));

    List<Writable> out = transformProcess.execute(in);
    assertEquals(exp, out);
}
 
Example #14
Source File: TestImageRecordReader.java    From DataVec with Apache License 2.0
private static Writable testLabel(String filename){
    switch(filename){
        case "0.jpg":
            return new DoubleWritable(0.0);
        case "1.png":
            return new DoubleWritable(1.0);
        case "2.jpg":
            return new DoubleWritable(2.0);
        case "A.jpg":
            return new DoubleWritable(10);
        case "B.png":
            return new DoubleWritable(11);
        case "C.jpg":
            return new DoubleWritable(12);
        default:
            throw new RuntimeException(filename);
    }
}
 
Example #15
Source File: VideoRecordReader.java    From DataVec with Apache License 2.0
@Override
public List<List<Writable>> sequenceRecord() {
    File next = iter.next();
    invokeListeners(next);
    if (!next.isDirectory())
        return Collections.emptyList();
    File[] list = next.listFiles();
    List<List<Writable>> ret = new ArrayList<>();
    for (File f : list) {
        try {
            List<Writable> record = RecordConverter.toRecord(imageLoader.asRowVector(f));
            ret.add(record);
            if (appendLabel)
                record.add(new DoubleWritable(labels.indexOf(next.getName())));
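            //record was already added to ret, so mutating it here still puts the label into the returned sequence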
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

    }
    return ret;
}
 
Example #16
Source File: ExecutionTest.java    From deeplearning4j with Apache License 2.0
@Test(timeout = 60000L)
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
public void testPythonExecution() throws Exception {
    Schema schema = new Schema.Builder().addColumnInteger("col0")
            .addColumnString("col1").addColumnDouble("col2").build();

    Schema finalSchema = new Schema.Builder().addColumnInteger("col0")
            .addColumnInteger("col1").addColumnDouble("col2").build();
    String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
    TransformProcess tp = new TransformProcess.Builder(schema).transform(
            PythonTransform.builder().code(pythonCode)
                    .outputSchema(finalSchema).build()
    ).build();
    List<List<Writable>> inputData = new ArrayList<>();
    inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));

    JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);

    List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());

    Collections.sort(out, new Comparator<List<Writable>>() {
        @Override
        public int compare(List<Writable> o1, List<Writable> o2) {
            return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
        }
    });

    List<List<Writable>> expected = new ArrayList<>();
    expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
    expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
    expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));

    assertEquals(expected, out);
}
 
Example #17
Source File: ExecutionTest.java    From DataVec with Apache License 2.0
@Test
public void testReductionGlobal() {

    List<List<Writable>> in = Arrays.asList(
            Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)),
            Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))
    );

    JavaRDD<List<Writable>> inData = sc.parallelize(in);

    Schema s = new Schema.Builder()
            .addColumnString("textCol")
            .addColumnDouble("doubleCol")
            .build();

    TransformProcess tp = new TransformProcess.Builder(s)
            .reduce(new Reducer.Builder(ReduceOp.TakeFirst)
                    .takeFirstColumns("textCol")
                    .meanColumns("doubleCol").build())
            .build();

    JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);

    List<List<Writable>> out = outRdd.collect();

    List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));

    assertEquals(expOut, out);
}
 
Example #18
Source File: LibSvmRecordWriterTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 0);
    arr3.putScalar(1, 1);
    arr3.putScalar(2, 0);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
            new NDArrayWritable(arr2),
            new IntWritable(2),
            new DoubleWritable(3),
            new NDArrayWritable(arr3),
            new DoubleWritable(1));
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #19
Source File: TestFilters.java    From DataVec with Apache License 2.0
@Test
public void testFilterInvalidValues() {

    List<List<Writable>> list = new ArrayList<>();
    list.add(Collections.singletonList((Writable) new IntWritable(-1)));
    list.add(Collections.singletonList((Writable) new IntWritable(0)));
    list.add(Collections.singletonList((Writable) new IntWritable(2)));

    Schema schema = new Schema.Builder().addColumnInteger("intCol", 0, 10) //Only values in the range 0 to 10 are ok
                    .addColumnDouble("doubleCol", -100.0, 100.0) //-100 to 100 only; no NaN or infinite
                    .build();

    Filter filter = new FilterInvalidValues("intCol", "doubleCol");
    filter.setInputSchema(schema);

    //Test valid examples:
    assertFalse(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(0))));
    assertFalse(filter.removeExample(asList((Writable) new IntWritable(10), new DoubleWritable(0))));
    assertFalse(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(-100))));
    assertFalse(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(100))));

    //Test invalid:
    assertTrue(filter.removeExample(asList((Writable) new IntWritable(-1), new DoubleWritable(0))));
    assertTrue(filter.removeExample(asList((Writable) new IntWritable(11), new DoubleWritable(0))));
    assertTrue(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(-101))));
    assertTrue(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(101))));
}
 
Example #20
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0
@Test
public void testNDArrayWritables() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 13);
    arr3.putScalar(1, 14);
    arr3.putScalar(2, 15);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
                                        new NDArrayWritable(arr2),
                                        new IntWritable(2),
                                        new DoubleWritable(3),
                                        new NDArrayWritable(arr3),
                                        new IntWritable(4));
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #21
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0
@Test
public void testNDArrayWritablesMultilabel() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 0);
    arr3.putScalar(1, 1);
    arr3.putScalar(2, 0);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
            new NDArrayWritable(arr2),
            new IntWritable(2),
            new DoubleWritable(3),
            new NDArrayWritable(arr3),
            new DoubleWritable(1));
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #22
Source File: TestWritablesToStringFunctions.java    From deeplearning4j with Apache License 2.0
@Test
public void testWritablesToString() throws Exception {

    List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue"));
    String expected = l.get(0).toString() + "," + l.get(1).toString();

    assertEquals(expected, new WritablesToStringFunction(",").apply(l));
}
 
Example #23
Source File: MinMaxNormalizer.java    From DataVec with Apache License 2.0
/**
 * Transform an object into another object.
 *
 * @param input the record to transform
 * @return the transformed value
 */
@Override
public Object map(Object input) {
    Number n = (Number) input;
    double val = n.doubleValue();
    if (Double.isNaN(val))
        return new DoubleWritable(0);
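    //Linear rescale of [min, max] onto [newMin, newMax]; ratio is presumably precomputed as (newMax - newMin) / (max - min)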
    return ratio * (val - min) + newMin;
}
 
Example #24
Source File: RecordUtils.java    From DataVec with Apache License 2.0
public static List<Writable> toRecord(double[] record) {
    List<Writable> ret = new ArrayList<>(record.length);
    for (int i = 0; i < record.length; i++)
        ret.add(new DoubleWritable(record[i]));

    return ret;
}
 
Example #25
Source File: NDArrayDistanceTransform.java    From DataVec with Apache License 2.0
@Override
public List<Writable> map(List<Writable> writables) {
    int idxFirst = inputSchema.getIndexOfColumn(firstCol);
    int idxSecond = inputSchema.getIndexOfColumn(secondCol);

    INDArray arr1 = ((NDArrayWritable) writables.get(idxFirst)).get();
    INDArray arr2 = ((NDArrayWritable) writables.get(idxSecond)).get();

    double d;
    switch (distance) {
        case COSINE:
            d = Transforms.cosineSim(arr1, arr2);
            break;
        case EUCLIDEAN:
            d = Transforms.euclideanDistance(arr1, arr2);
            break;
        case MANHATTAN:
            d = Transforms.manhattanDistance(arr1, arr2);
            break;
        default:
            throw new UnsupportedOperationException("Unknown or not supported distance metric: " + distance);
    }

    List<Writable> out = new ArrayList<>(writables.size() + 1);
    out.addAll(writables);
    out.add(new DoubleWritable(d));

    return out;
}
 
Example #26
Source File: TestNDArrayToWritablesFunction.java    From deeplearning4j with Apache License 2.0
@Test
public void testNDArrayToWritablesScalars() throws Exception {
    INDArray arr = Nd4j.arange(5);
    List<Writable> expected = new ArrayList<>();
    for (int i = 0; i < 5; i++)
        expected.add(new DoubleWritable(i));
    List<Writable> actual = new NDArrayToWritablesFunction().call(arr);
    assertEquals(expected, actual);
}
 
Example #27
Source File: TestCalculateSortedRank.java    From deeplearning4j with Apache License 2.0
@Test
public void testCalculateSortedRank() {

    List<List<Writable>> data = new ArrayList<>();
    data.add(Arrays.asList((Writable) new Text("0"), new DoubleWritable(0.0)));
    data.add(Arrays.asList((Writable) new Text("3"), new DoubleWritable(0.3)));
    data.add(Arrays.asList((Writable) new Text("2"), new DoubleWritable(0.2)));
    data.add(Arrays.asList((Writable) new Text("1"), new DoubleWritable(0.1)));

    JavaRDD<List<Writable>> rdd = sc.parallelize(data);

    Schema schema = new Schema.Builder().addColumnsString("TextCol").addColumnDouble("DoubleCol").build();

    TransformProcess tp = new TransformProcess.Builder(schema)
                    .calculateSortedRank("rank", "DoubleCol", new DoubleWritableComparator()).build();

    Schema outSchema = tp.getFinalSchema();
    assertEquals(3, outSchema.numColumns());
    assertEquals(Arrays.asList("TextCol", "DoubleCol", "rank"), outSchema.getColumnNames());
    assertEquals(Arrays.asList(ColumnType.String, ColumnType.Double, ColumnType.Long), outSchema.getColumnTypes());

    JavaRDD<List<Writable>> out = SparkTransformExecutor.execute(rdd, tp);

    List<List<Writable>> collected = out.collect();
    assertEquals(4, collected.size());
    for (int i = 0; i < 4; i++)
        assertEquals(3, collected.get(i).size());

    for (List<Writable> example : collected) {
        int exampleNum = example.get(0).toInt();
        int rank = example.get(2).toInt();
        assertEquals(exampleNum, rank);
    }
}
 
Example #28
Source File: Log2Normalizer.java    From DataVec with Apache License 2.0
/**
 * Transform an object into another object.
 *
 * @param input the record to transform
 * @return the transformed value
 */
@Override
public Object map(Object input) {
    Number n = (Number) input;
    double val = n.doubleValue();
    if (Double.isNaN(val))
        return new DoubleWritable(0);
    return normMean(val);
}
 
Example #29
Source File: ParseDoubleTransform.java    From DataVec with Apache License 2.0
@Override
public List<Writable> map(List<Writable> writables) {
    List<Writable> transform = new ArrayList<>();
    for (Writable w : writables) {
        if (w instanceof Text) {
            transform.add(new DoubleWritable(w.toDouble()));
        } else {
            transform.add(w);
        }
    }
    return transform;
}
 
Example #30
Source File: ConvertToDouble.java    From DataVec with Apache License 2.0
@Override
public DoubleWritable map(Writable writable) {
    if(writable.getType() == WritableType.Double){
        return (DoubleWritable)writable;
    }
    return new DoubleWritable(writable.toDouble());
}
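
As a closing illustration, here is a hedged sketch of wiring the ConvertToDouble transform from Example #30 into a pipeline. It assumes TransformProcess.Builder exposes a convertToDouble(...) convenience mirroring the convertToString(...) calls seen in Examples #6 and #7:

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;

public class ConvertToDoubleUsage {
    public static void main(String[] args) {
        // One string column, e.g. "2.5" arriving as text from a CSV
        Schema schema = new Schema.Builder()
                .addColumnString("rawValue")
                .build();

        // convertToDouble is assumed here; a transform like map(Writable) above
        // turns each Text into a DoubleWritable
        TransformProcess tp = new TransformProcess.Builder(schema)
                .convertToDouble("rawValue")
                .build();

        System.out.println(tp.getFinalSchema().getColumnTypes());  // [Double]
    }
}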