org.datavec.api.writable.IntWritable Java Examples

The following examples show how to use org.datavec.api.writable.IntWritable. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Test(expected = NumberFormatException.class)
public void nonBinaryMultilabel() throws Exception {
    // In MULTILABEL mode every label column must be a 0/1 indicator; the
    // trailing IntWritable(2) is out of range and must trigger a
    // NumberFormatException when the record is written.
    List<Writable> record = Arrays.<Writable>asList(
            new IntWritable(0), new IntWritable(1), new IntWritable(2));

    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists()) {
        tempFile.delete();
    }

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration conf = new Configuration();
        conf.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        conf.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        conf.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
        writer.initialize(conf, new FileSplit(tempFile), new NumberOfRecordsPartitioner());
        writer.write(record);  // expected to throw here
    }
}
 
Example #2
Source File: TextToCharacterIndexTransform.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Override
protected List<List<Writable>> expandTimeStep(List<Writable> currentStepValues) {
    // Expands a single text value into one time step per character, where each
    // step holds the character's integer index (taken from characterIndexMap).
    // Lazily build the Character -> singleton [IntWritable(index)] lookup on first use.
    if(writableMap == null){
        Map<Character,List<Writable>> m = new HashMap<>();
        for(Map.Entry<Character,Integer> entry : characterIndexMap.entrySet()){
            m.put(entry.getKey(), Collections.<Writable>singletonList(new IntWritable(entry.getValue())));
        }
        writableMap = m;
    }
    List<List<Writable>> out = new ArrayList<>();
    // Only the first writable of the step is used; it is treated as the text to expand.
    char[] cArr = currentStepValues.get(0).toString().toCharArray();
    for( char c : cArr ){
        List<Writable> w = writableMap.get(c);
        if(w == null ){
            // Unknown character: either fail fast or skip it, depending on configuration.
            if(exceptionOnUnknown){
                throw new IllegalStateException("Unknown character found in text: \"" + c + "\"");
            }
            continue;
        }

        out.add(w);
    }

    return out;
}
 
Example #3
Source File: ArrowConverterTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Test
public void testArrowBatchSet() {
    // Build a two-column integer schema with columns named "0" and "1".
    Schema.Builder schema = new Schema.Builder();
    List<String> single = new ArrayList<>();
    for(int i = 0; i < 2; i++) {
        schema.addColumnInteger(String.valueOf(i));
        single.add(String.valueOf(i));
    }

    // Two input rows: [0,1] and [2,3].
    List<List<Writable>> input = Arrays.asList(
            Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)),
            Arrays.<Writable>asList(new IntWritable(2),new IntWritable(3))
    );

    List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
    ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());

    // Overwrite row 1 and verify the same values are read back.
    // (Previously the [4,5] list was duplicated inline; reuse the single
    // `assertion` instance so the written and expected values cannot diverge.)
    List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
    writableRecordBatch.set(1, assertion);
    List<Writable> recordTest = writableRecordBatch.get(1);
    assertEquals(assertion,recordTest);
}
 
Example #4
Source File: TransformProcessRecordReaderTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void simpleTransformTestSequence() {
    // Builds a 3-step sequence of [timestamp, int, int] rows, then verifies that
    // removing "intcolumn2" via a TransformProcess leaves 2 columns per step.
    List<List<Writable>> sequence = new ArrayList<>();
    //First window: three steps 100ms apart, starting at 2016-01-01T00:00:00 UTC
    sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0),
                    new IntWritable(0)));
    sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1),
                    new IntWritable(0)));
    sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2),
                    new IntWritable(0)));

    Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
                    .addColumnInteger("intcolumn").addColumnInteger("intcolumn2").build();
    // Drop the last column; the wrapped reader should apply this on the fly.
    TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build();
    InMemorySequenceRecordReader inMemorySequenceRecordReader =
                    new InMemorySequenceRecordReader(Arrays.asList(sequence));
    TransformProcessSequenceRecordReader transformProcessSequenceRecordReader =
                    new TransformProcessSequenceRecordReader(inMemorySequenceRecordReader, transformProcess);
    List<List<Writable>> next = transformProcessSequenceRecordReader.sequenceRecord();
    // 3 original columns minus the removed one = 2 columns per time step.
    assertEquals(2, next.get(0).size());

}
 
Example #5
Source File: ExcelRecordWriterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
private Triple<String,Schema,List<List<Writable>>> records() {
    // Generates 10 rows x 3 integer columns, where row i holds
    // {100*i, 100*i+1, 100*i+2}. Also builds the matching CSV text
    // (comma-separated, newline between rows, no trailing newline) and a
    // schema with integer columns named "0", "1", "2".
    final int numColumns = 3;
    final int numRows = 10;
    List<List<Writable>> rows = new ArrayList<>();
    StringBuilder csv = new StringBuilder();
    for (int row = 0; row < numRows; row++) {
        List<Writable> current = new ArrayList<>();
        for (int col = 0; col < numColumns; col++) {
            int value = 100 * row + col;
            current.add(new IntWritable(value));
            csv.append(value);
            if (col < numColumns - 1) {
                csv.append(",");
            } else if (row < numRows - 1) {
                csv.append("\n");
            }
        }
        rows.add(current);
    }

    Schema.Builder schemaBuilder = new Schema.Builder();
    for (int col = 0; col < numColumns; col++) {
        schemaBuilder.addColumnInteger(String.valueOf(col));
    }

    return Triple.of(csv.toString(), schemaBuilder.build(), rows);
}
 
Example #6
Source File: ExcelRecordWriterTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
private Triple<String,Schema,List<List<Writable>>> records() {
    // Builds a 10x3 grid of integer writables (row i = 100*i .. 100*i+2),
    // the equivalent CSV string (no trailing newline), and a matching schema
    // with integer columns named "0", "1", "2".
    List<List<Writable>> list = new ArrayList<>();
    StringBuilder sb = new StringBuilder();
    int numColumns = 3;
    for (int i = 0; i < 10; i++) {
        List<Writable> temp = new ArrayList<>();
        for (int j = 0; j < numColumns; j++) {
            int v = 100 * i + j;
            temp.add(new IntWritable(v));
            sb.append(v);
            // Comma between columns; newline between rows, except after the last row.
            if (j < 2)
                sb.append(",");
            else if (i != 9)
                sb.append("\n");
        }
        list.add(temp);
    }


    Schema.Builder schemaBuilder = new Schema.Builder();
    for(int i = 0; i < numColumns; i++) {
        schemaBuilder.addColumnInteger(String.valueOf(i));
    }

    return Triple.of(sb.toString(),schemaBuilder.build(),list);
}
 
Example #7
Source File: CSVRecordReaderTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test(expected = NoSuchElementException.class)
public void testCsvSkipAllLines() throws IOException, InterruptedException {
    // Write numLines CSV lines, then configure the reader to skip all of them:
    // next() on the exhausted reader must throw NoSuchElementException.
    // (Removed the unused `lineList` local from the original version.)
    final int numLines = 4;
    String header = ",one,two,three";
    List<String> lines = new ArrayList<>();
    for (int i = 0; i < numLines; i++)
        lines.add(Integer.toString(i) + header);
    File tempFile = File.createTempFile("csvSkipLines", ".csv");
    tempFile.deleteOnExit(); // avoid leaking the temp file after the test run
    FileUtils.writeLines(tempFile, lines);

    CSVRecordReader rr = new CSVRecordReader(numLines, ',');
    rr.initialize(new FileSplit(tempFile));
    rr.reset();
    assertFalse(rr.hasNext());
    rr.next(); // throws: nothing left to read
}
 
Example #8
Source File: LibSvmRecordWriterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test(expected = NumberFormatException.class)
public void nonIntegerMultilabel() throws Exception {
    // In MULTILABEL mode every label column must hold an integer value; the
    // non-integral DoubleWritable(1.2) should trigger a NumberFormatException.
    List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
                                            new IntWritable(2),
                                            new DoubleWritable(1.2));
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    // Start from a nonexistent file so the writer creates it fresh.
    if (tempFile.exists())
        tempFile.delete();

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        // Columns 0-1 are features; everything after is treated as labels.
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
        configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }
}
 
Example #9
Source File: RecordMapperTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
private Triple<String,Schema,List<List<Writable>>> records() {
    // Builds a 10x3 grid of integer writables (row i = 100*i .. 100*i+2),
    // the equivalent CSV string (no trailing newline), and a matching schema
    // with integer columns named "0", "1", "2".
    List<List<Writable>> list = new ArrayList<>();
    StringBuilder sb = new StringBuilder();
    int numColumns = 3;
    for (int i = 0; i < 10; i++) {
        List<Writable> temp = new ArrayList<>();
        for (int j = 0; j < numColumns; j++) {
            int v = 100 * i + j;
            temp.add(new IntWritable(v));
            sb.append(v);
            // Comma between columns; newline between rows, except after the last row.
            if (j < 2)
                sb.append(",");
            else if (i != 9)
                sb.append("\n");
        }
        list.add(temp);
    }


    Schema.Builder schemaBuilder = new Schema.Builder();
    for(int i = 0; i < numColumns; i++) {
        schemaBuilder.addColumnInteger(String.valueOf(i));
    }

    return Triple.of(sb.toString(),schemaBuilder.build(),list);
}
 
Example #10
Source File: TextToTermIndexSequenceTransform.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Override
protected List<List<Writable>> expandTimeStep(List<Writable> currentStepValues) {
    // Expands a single text value into one time step per token, where each step
    // holds the token's integer index (taken from wordIndexMap).
    // Lazily build the token -> singleton [IntWritable(index)] lookup on first use.
    if(writableMap == null){
        Map<String,List<Writable>> m = new HashMap<>();
        for(Map.Entry<String,Integer> entry : wordIndexMap.entrySet()) {
            m.put(entry.getKey(), Collections.<Writable>singletonList(new IntWritable(entry.getValue())));
        }
        writableMap = m;
    }
    List<List<Writable>> out = new ArrayList<>();
    // Only the first writable of the step is used; split it into tokens.
    String text = currentStepValues.get(0).toString();
    String[] tokens = text.split(this.delimiter);
    for(String token : tokens ){
        List<Writable> w = writableMap.get(token);
        if(w == null) {
            // Unknown token: either fail fast or skip it, depending on configuration.
            if(exceptionOnUnknown) {
                throw new IllegalStateException("Unknown token found in text: \"" + token + "\"");
            }
            continue;
        }
        out.add(w);
    }

    return out;
}
 
Example #11
Source File: TransformProcessRecordReaderTests.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Test
public void simpleTransformTestSequence() {
    // Three time steps 100ms apart: [timestamp, step-index, 0].
    final long baseTime = 1451606400000L;
    List<List<Writable>> sequence = new ArrayList<>();
    for (int step = 0; step < 3; step++) {
        sequence.add(Arrays.asList((Writable) new LongWritable(baseTime + 100L * step),
                new IntWritable(step), new IntWritable(0)));
    }

    Schema schema = new SequenceSchema.Builder()
            .addColumnTime("timecolumn", DateTimeZone.UTC)
            .addColumnInteger("intcolumn")
            .addColumnInteger("intcolumn2")
            .build();
    // Removing "intcolumn2" should leave two columns per time step.
    TransformProcess transformProcess =
            new TransformProcess.Builder(schema).removeColumns("intcolumn2").build();
    InMemorySequenceRecordReader reader =
            new InMemorySequenceRecordReader(Arrays.asList(sequence));
    TransformProcessSequenceRecordReader transformed =
            new TransformProcessSequenceRecordReader(reader, transformProcess);
    List<List<Writable>> next = transformed.sequenceRecord();
    assertEquals(2, next.get(0).size());
}
 
Example #12
Source File: TestRecordReaders.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testClassIndexOutsideOfRangeRRMDSI() {
    // The iterator is configured for 2 label classes, but seq2 contains label
    // index 2 (valid range 0..1): one-hot conversion must fail on iteration.
    Collection<Collection<Collection<Writable>>> c = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(1)));
    c.add(seq1);

    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    c.add(seq2);

    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, 2, 2, 1);

    try {
        // Result intentionally discarded (the original bound it to an unused
        // local); only the thrown exception matters.
        dsi.next();
        fail("Expected exception");
    } catch (Exception e) {
        assertTrue(e.getMessage(), e.getMessage().contains("to one-hot"));
    }
}
 
Example #13
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Test(expected = NumberFormatException.class)
public void nonIntegerMultilabel() throws Exception {
    // In multilabel mode all label columns must hold integer values, so the
    // trailing DoubleWritable(1.2) should cause a NumberFormatException on write.
    List<Writable> record = Arrays.<Writable>asList(
            new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));

    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists()) {
        tempFile.delete();
    }

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration conf = new Configuration();
        conf.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        conf.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        conf.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
        writer.initialize(conf, new FileSplit(tempFile), new NumberOfRecordsPartitioner());
        writer.write(record);  // expected to throw here
    }
}
 
Example #14
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Test
public void testNonIntegerButValidMultilabel() throws Exception {
    // DoubleWritable(1.0) is numerically integral, so even though it is not an
    // IntWritable it should be accepted as a multilabel label (no exception).
    List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
            new IntWritable(2),
            new DoubleWritable(1.0));
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    // Start from a nonexistent file so the writer creates it fresh.
    if (tempFile.exists())
        tempFile.delete();

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        // Columns 0-1 are features; everything after is treated as labels.
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
        configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }
}
 
Example #15
Source File: LongMetaData.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Override
public boolean isValid(Writable writable) {
    // Obtain a long either directly (int/long writables) or by parsing the
    // string form; a non-numeric string means the value is invalid.
    final long value;
    if (writable instanceof IntWritable || writable instanceof LongWritable) {
        value = writable.toLong();
    } else {
        try {
            value = Long.parseLong(writable.toString());
        } catch (NumberFormatException e) {
            return false;
        }
    }
    // Enforce the optional [minAllowedValue, maxAllowedValue] bounds;
    // a null bound means "unbounded" on that side.
    boolean aboveMin = minAllowedValue == null || value >= minAllowedValue;
    boolean belowMax = maxAllowedValue == null || value <= maxAllowedValue;
    return aboveMin && belowMax;
}
 
Example #16
Source File: TestFilters.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testFilterNumColumns() {
    // Three single-column examples against a two-column schema: the
    // column-count filter must remove all of them, and keep a proper 2-column row.
    List<List<Writable>> list = new ArrayList<>();
    list.add(Collections.singletonList((Writable) new IntWritable(-1)));
    list.add(Collections.singletonList((Writable) new IntWritable(0)));
    list.add(Collections.singletonList((Writable) new IntWritable(2)));

    Schema schema = new Schema.Builder().addColumnInteger("intCol", 0, 10) //Only values in the range 0 to 10 are ok
                    .addColumnDouble("doubleCol", -100.0, 100.0) //-100 to 100 only; no NaN or infinite
                    .build();
    Filter numColumns = new InvalidNumColumns(schema);
    // Every single-column example has the wrong column count -> removed.
    for (int i = 0; i < list.size(); i++)
        assertTrue(numColumns.removeExample(list.get(i)));

    // A row matching the schema's two columns must be kept.
    List<Writable> correct = Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(2));
    assertFalse(numColumns.removeExample(correct));

}
 
Example #17
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testRecordReaderDataSetIteratorConcat() {

    //[DoubleWritable, DoubleWritable, NDArrayWritable([1,10]), IntWritable] -> concatenate to a [1,13] feature vector automatically.
    // Mixed scalar and array writables: the first 5 entries flatten to the
    // 9-element feature vector 1..9, and the final IntWritable(1) is the
    // class label (label column index 5, 3 classes -> one-hot [0,1,0]).

    List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1),
                    new NDArrayWritable(Nd4j.create(new double[] {2, 3, 4})), new DoubleWritable(5),
                    new NDArrayWritable(Nd4j.create(new double[] {6, 7, 8})), new IntWritable(9),
                    new IntWritable(1));

    RecordReader rr = new CollectionRecordReader(Collections.singletonList(l));

    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 5, 3);

    DataSet ds = iter.next();
    INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,9});
    INDArray expL = Nd4j.create(new float[] {0, 1, 0}, new int[]{1,3});

    assertEquals(expF, ds.getFeatures());
    assertEquals(expL, ds.getLabels());
}
 
Example #18
Source File: TextToCharacterIndexTransform.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Override
protected List<List<Writable>> expandTimeStep(List<Writable> currentStepValues) {
    // Lazily build the character -> singleton [IntWritable(index)] lookup table.
    if (writableMap == null) {
        Map<Character, List<Writable>> lookup = new HashMap<>();
        for (Map.Entry<Character, Integer> entry : characterIndexMap.entrySet()) {
            lookup.put(entry.getKey(),
                    Collections.<Writable>singletonList(new IntWritable(entry.getValue())));
        }
        writableMap = lookup;
    }

    // Expand the first writable's text into one time step per known character.
    String text = currentStepValues.get(0).toString();
    List<List<Writable>> out = new ArrayList<>();
    for (int i = 0; i < text.length(); i++) {
        char c = text.charAt(i);
        List<Writable> mapped = writableMap.get(c);
        if (mapped != null) {
            out.add(mapped);
        } else if (exceptionOnUnknown) {
            throw new IllegalStateException("Unknown character found in text: \"" + c + "\"");
        }
        // else: silently skip characters that have no index
    }
    return out;
}
 
Example #19
Source File: LongMetaData.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Override
public boolean isValid(Writable writable) {
    // Validates that the writable holds a long value within the optional
    // [minAllowedValue, maxAllowedValue] bounds (null bound = unbounded).
    long value;
    if (writable instanceof IntWritable || writable instanceof LongWritable) {
        value = writable.toLong();
    } else {
        // Fall back to parsing the string form; non-numeric text is invalid.
        try {
            value = Long.parseLong(writable.toString());
        } catch (NumberFormatException e) {
            return false;
        }
    }
    if (minAllowedValue != null && value < minAllowedValue)
        return false;
    if (maxAllowedValue != null && value > maxAllowedValue)
        return false;

    return true;
}
 
Example #20
Source File: TestFilters.java    From DataVec with Apache License 2.0 6 votes vote down vote up
@Test
public void testFilterNumColumns() {
    // Three single-column rows; the schema expects two columns, so every
    // single-column example must be removed by the column-count filter.
    List<List<Writable>> singleColumnRows = new ArrayList<>();
    for (int value : new int[] {-1, 0, 2}) {
        singleColumnRows.add(Collections.singletonList((Writable) new IntWritable(value)));
    }

    Schema schema = new Schema.Builder()
                    .addColumnInteger("intCol", 0, 10) //Only values in the range 0 to 10 are ok
                    .addColumnDouble("doubleCol", -100.0, 100.0) //-100 to 100 only; no NaN or infinite
                    .build();
    Filter numColumns = new InvalidNumColumns(schema);
    for (List<Writable> row : singleColumnRows) {
        assertTrue(numColumns.removeExample(row));
    }

    // A row with the correct number of columns must be kept.
    List<Writable> correct = Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(2));
    assertFalse(numColumns.removeExample(correct));
}
 
Example #21
Source File: TestWritablesToNDArrayFunction.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testWritablesToNDArrayMixed() throws Exception {
    // Mixes scalar IntWritables with [1,3] NDArrayWritables; concatenation
    // should yield one flat [1,10] row holding the values 0..9 in order.
    Nd4j.setDataType(DataType.FLOAT);
    List<Writable> l = new ArrayList<>();
    l.add(new IntWritable(0));
    l.add(new IntWritable(1));
    INDArray arr = Nd4j.arange(2, 5).reshape(1, 3);
    l.add(new NDArrayWritable(arr));
    l.add(new IntWritable(5));
    arr = Nd4j.arange(6, 9).reshape(1, 3);
    l.add(new NDArrayWritable(arr));
    l.add(new IntWritable(9));

    INDArray expected = Nd4j.arange(10).castTo(DataType.FLOAT).reshape(1, 10);
    assertEquals(expected, new WritablesToNDArrayFunction().apply(l));
}
 
Example #22
Source File: TestRecordReaders.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    // The iterator is configured for 2 label classes, but the second record
    // holds class index 2 (valid range 0..1): one-hot conversion must fail.
    Collection<Collection<Writable>> c = new ArrayList<>();
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new IntWritable(0)));
    c.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new IntWritable(2)));

    CollectionRecordReader crr = new CollectionRecordReader(c);

    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2);

    try {
        // Result intentionally discarded (the original bound it to an unused
        // local); only the thrown exception matters.
        iter.next();
        fail("Expected exception");
    } catch (Exception e) {
        assertTrue(e.getMessage(), e.getMessage().contains("to one-hot"));
    }
}
 
Example #23
Source File: LocalTransformProcessRecordReaderTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void simpleTransformTestSequence() {
    // Builds a 3-step sequence of [timestamp, int, int] rows, then verifies that
    // removing "intcolumn2" via a TransformProcess leaves 2 columns per step.
    List<List<Writable>> sequence = new ArrayList<>();
    //First window: three steps 100ms apart, starting at 2016-01-01T00:00:00 UTC
    sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0),
            new IntWritable(0)));
    sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1),
            new IntWritable(0)));
    sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2),
            new IntWritable(0)));

    Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
            .addColumnInteger("intcolumn").addColumnInteger("intcolumn2").build();
    // Drop the last column; the wrapped reader should apply this on the fly.
    TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build();
    InMemorySequenceRecordReader inMemorySequenceRecordReader =
            new InMemorySequenceRecordReader(Arrays.asList(sequence));
    LocalTransformProcessSequenceRecordReader transformProcessSequenceRecordReader =
            new LocalTransformProcessSequenceRecordReader(inMemorySequenceRecordReader, transformProcess);
    List<List<Writable>> next = transformProcessSequenceRecordReader.sequenceRecord();
    // 3 original columns minus the removed one = 2 columns per time step.
    assertEquals(2, next.get(0).size());

}
 
Example #24
Source File: LibSvmRecordWriterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testNonIntegerButValidMultilabel() throws Exception {
    // DoubleWritable(1.0) is numerically integral, so it should be accepted
    // as a multilabel label even though it is not an IntWritable.
    List<Writable> record = Arrays.<Writable>asList(
            new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));

    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists()) {
        tempFile.delete();
    }

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration conf = new Configuration();
        conf.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        conf.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        conf.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
        writer.initialize(conf, new FileSplit(tempFile), new NumberOfRecordsPartitioner());
        writer.write(record);  // must not throw
    }
}
 
Example #25
Source File: TestFilters.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testConditionFilter() {
    // The filter removes examples where "column" < 0; non-negatives are kept.
    Schema schema = new Schema.Builder().addColumnInteger("column").build();

    Condition condition = new IntegerColumnCondition("column", ConditionOp.LessThan, 0);
    condition.setInputSchema(schema);

    Filter filter = new ConditionFilter(condition);

    for (int kept : new int[] {10, 1, 0}) {
        assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(kept))));
    }
    for (int removed : new int[] {-1, -10}) {
        assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(removed))));
    }
}
 
Example #26
Source File: ExecutionTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Test(timeout = 60000L)
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
public void testPythonExecution() throws Exception {
    Schema schema = new Schema.Builder().addColumnInteger("col0")
            .addColumnString("col1").addColumnDouble("col2").build();

    Schema finalSchema = new Schema.Builder().addColumnInteger("col0")
            .addColumnInteger("col1").addColumnDouble("col2").build();
    String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
    TransformProcess tp = new TransformProcess.Builder(schema).transform(
            PythonTransform.builder().code(
                    "first = np.sin(first)\nsecond = np.cos(second)")
                    .outputSchema(finalSchema).build()
    ).build();
    List<List<Writable>> inputData = new ArrayList<>();
    inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));

    JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);

    List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());

    Collections.sort(out, new Comparator<List<Writable>>() {
        @Override
        public int compare(List<Writable> o1, List<Writable> o2) {
            return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
        }
    });

    List<List<Writable>> expected = new ArrayList<>();
    expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
    expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
    expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));

    assertEquals(expected, out);
}
 
Example #27
Source File: ConvertToInteger.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
public IntWritable map(Writable writable) {
    // Already an integer writable: return it unchanged; otherwise wrap the
    // converted int value in a fresh IntWritable.
    return writable.getType() == WritableType.Int
            ? (IntWritable) writable
            : new IntWritable(writable.toInt());
}
 
Example #28
Source File: SpecialImageRecordReader.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
@Override
public List<Writable> next() {
    // Produces a synthetic image record: a feature array filled from the
    // running counter, followed by a random class label writable.
    INDArray features = Nd4j.create(channels, height, width);
    fillNDArray(features, counter.getAndIncrement());
    // Add the leading batch dimension: [1, channels, height, width].
    features = features.reshape(1, channels, height, width);
    List<Writable> ret = RecordConverter.toRecord(features);
    // Append a random label in [0, numClasses) as the final writable.
    ret.add(new IntWritable(RandomUtils.nextInt(0, numClasses)));
    return ret;
}
 
Example #29
Source File: LibSvmRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
@Test
public void testBasicRecord() throws IOException, InterruptedException {
    // Reads svmlight/basic.txt with one-based indexing, 10 dense features, and
    // the label appended as the final column; checks each expanded record.
    // ZERO/ONE appear to be shared constant writables for 0.0/1.0 — defined
    // elsewhere in this test class.
    Map<Integer, List<Writable>> correct = new HashMap<>();
    // 7 2:1 4:2 6:3 8:4 10:5
    correct.put(0, Arrays.asList(ZERO, ONE,
                                ZERO, new DoubleWritable(2),
                                ZERO, new DoubleWritable(3),
                                ZERO, new DoubleWritable(4),
                                ZERO, new DoubleWritable(5),
                                new IntWritable(7)));
    // 2 qid:42 1:0.1 2:2 6:6.6 8:80
    correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
                                ZERO, ZERO,
                                ZERO, new DoubleWritable(6.6),
                                ZERO, new DoubleWritable(80),
                                ZERO, ZERO,
                                new IntWritable(2)));
    // 33
    correct.put(2, Arrays.asList(ZERO, ZERO,
                                ZERO, ZERO,
                                ZERO, ZERO,
                                ZERO, ZERO,
                                ZERO, ZERO,
                                new IntWritable(33)));

    LibSvmRecordReader rr = new LibSvmRecordReader();
    Configuration config = new Configuration();
    // File uses 1-based feature indices; expand to 10 dense columns + label.
    config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
    config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
    config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
    rr.initialize(config, new FileSplit(new ClassPathResource("svmlight/basic.txt").getFile()));
    int i = 0;
    while (rr.hasNext()) {
        List<Writable> record = rr.next();
        assertEquals(correct.get(i), record);
        i++;
    }
    // All expected records were consumed — no extras, none missing.
    assertEquals(i, correct.size());
}
 
Example #30
Source File: ExecutionTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
@Test
public void testUniqueMultiCol(){

    Schema schema = new Schema.Builder()
            .addColumnInteger("col0")
            .addColumnCategorical("col1", "state0", "state1", "state2")
            .addColumnDouble("col2").build();

    List<List<Writable>> inputData = new ArrayList<>();
    inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
    inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));

    JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);

    Map<String,List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd);

    assertEquals(2, l.size());
    List<Writable> c0 = l.get("col0");
    assertEquals(3, c0.size());
    assertTrue(c0.contains(new IntWritable(0)) && c0.contains(new IntWritable(1)) && c0.contains(new IntWritable(2)));

    List<Writable> c1 = l.get("col1");
    assertEquals(3, c1.size());
    assertTrue(c1.contains(new Text("state0")) && c1.contains(new Text("state1")) && c1.contains(new Text("state2")));
}