Java Code Examples for org.datavec.api.writable.Writable

The following examples show how to use org.datavec.api.writable.Writable. These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source Project: DataVec   Source File: JDBCRecordReader.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Converts one row of raw column values into a list of {@link Writable}s.
 *
 * <p>Listeners are notified with the raw row first. When {@code trimStrings} is set, String
 * values are trimmed before conversion. Each value is converted according to the column type
 * reported by {@code meta} for that column.
 *
 * @param item raw column values of a single row, in column order
 * @return the converted values, in the same order as {@code item}
 * @throws RuntimeException if the column type lookup fails; {@code closeJdbc()} is called first
 *         and the original {@link SQLException} is attached as the cause
 */
private List<Writable> toWritable(Object[] item) {
    invokeListeners(item);
    List<Writable> ret = new ArrayList<>(item.length);
    for (int i = 0; i < item.length; i++) {
        try {
            Object columnValue = item[i];
            if (trimStrings && columnValue instanceof String) {
                columnValue = ((String) columnValue).trim();
            }
            // Note, getColumnType first argument is column number starting from 1
            Writable writable = JdbcWritableConverter.convert(columnValue, meta.getColumnType(i + 1));
            ret.add(writable);
        } catch (SQLException e) {
            closeJdbc();
            // Keep the underlying SQLException so the real failure is visible in the stack trace
            throw new RuntimeException("Error reading database metadata", e);
        }
    }

    return ret;
}
 
Example 2
Source Project: deeplearning4j   Source File: DoubleMetaData.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Checks whether the writable holds a double that satisfies this column's constraints.
 *
 * <p>A value that cannot be converted with {@code toDouble()} is invalid. NaN is valid only when
 * {@code allowNaN} is set, and infinite values only when {@code allowInfinite} is set. Any other
 * value must not be below {@code minAllowedValue} or above {@code maxAllowedValue}; a bound that
 * is {@code null} is not checked.
 *
 * @param writable the value to check
 * @return {@code true} if the value is acceptable, {@code false} otherwise
 */
@Override
public boolean isValid(Writable writable) {
    double d;
    try {
        d = writable.toDouble();
    } catch (Exception e) {
        return false;
    }

    // Comparisons with NaN are always false, so NaN has to be decided before the range checks;
    // otherwise a disallowed NaN would slip through both bounds and be reported as valid.
    if (Double.isNaN(d)) {
        return allowNaN;
    }
    // Likewise a disallowed infinity must not pass just because the matching bound is unset.
    if (Double.isInfinite(d)) {
        return allowInfinite;
    }

    if (minAllowedValue != null && d < minAllowedValue) {
        return false;
    }
    if (maxAllowedValue != null && d > maxAllowedValue) {
        return false;
    }

    return true;
}
 
Example 3
Source Project: DataVec   Source File: ArrowWritableRecordBatch.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Overwrites the row at index {@code i} with the supplied values and returns the row's
 * previous content.
 *
 * @param i        row index, relative to {@code offset}
 * @param writable new values, one per schema column
 * @return the row as returned by {@code get(i)} before the update
 * @throws IllegalArgumentException if the number of values differs from the schema's column count
 */
@Override
public List<Writable> set(int i, List<Writable> writable) {
    final List<Writable> previous = get(i);
    final int supplied = writable.size();
    if (supplied != schema.numColumns()) {
        throw new IllegalArgumentException("Unable to set value. Wrong input types coming in");
    }

    // Row position inside the underlying vectors
    final int targetRow = offset + i;
    int column = 0;
    for (FieldVector vector : list) {
        ArrowConverter.setValue(schema.getType(column), vector, writable.get(column), targetRow);
        column += 1;
    }

    return previous;
}
 
Example 4
Source Project: deeplearning4j   Source File: TimeMathOpTransform.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Applies the configured math operation to a time value.
 *
 * @param columnWritable the time value, read via {@code toLong()}
 * @return a {@link LongWritable} holding the result of combining the value with {@code asMilliseconds}
 * @throws RuntimeException if {@code mathOp} is not one of Add, Subtract, ScalarMax, ScalarMin
 */
@Override
public Writable map(Writable columnWritable) {
    final long timestamp = columnWritable.toLong();
    final long updated;
    switch (mathOp) {
        case Add:
            updated = timestamp + asMilliseconds;
            break;
        case Subtract:
            updated = timestamp - asMilliseconds;
            break;
        case ScalarMax:
            updated = Math.max(asMilliseconds, timestamp);
            break;
        case ScalarMin:
            updated = Math.min(asMilliseconds, timestamp);
            break;
        default:
            throw new RuntimeException("Invalid MathOp for TimeMathOpTransform: " + mathOp);
    }
    return new LongWritable(updated);
}
 
Example 5
/**
 * Reads the next sequence and wraps it together with line-interval metadata.
 *
 * <p>The interval starts at the value of {@code lineIndex} before reading and ends at
 * {@code lineIndex + queue.size() - 1} after reading. The URI is taken from
 * {@code locations[splitIndex]}, or is {@code null} when no locations are available.
 *
 * @return the sequence record with its metadata
 */
@Override
public SequenceRecord nextSequence() {
    final int firstLine = lineIndex;
    final List<List<Writable>> sequence = sequenceRecord();
    final int lastLine = lineIndex + queue.size() - 1;

    URI source = null;
    if (locations != null && locations.length >= 1) {
        source = locations[splitIndex];
    }

    RecordMetaData metaData = new RecordMetaDataLineInterval(firstLine, lastLine, source,
                    CSVVariableSlidingWindowRecordReader.class);
    return new org.datavec.api.records.impl.SequenceRecord(sequence, metaData);
}
 
Example 6
/**
 * Renders a sequence as a single string.
 *
 * <p>Each time step is written with {@code WritablesToStringFunction.append} using
 * {@code delimiter} and {@code quote}; consecutive time steps are separated by
 * {@code timeStepDelimiter}.
 *
 * @param c the sequence, one inner list per time step
 * @return the joined string; empty when the sequence has no time steps
 */
@Override
public String apply(List<List<Writable>> c) {
    final StringBuilder joined = new StringBuilder();
    boolean needsSeparator = false;
    for (List<Writable> step : c) {
        if (needsSeparator) {
            joined.append(timeStepDelimiter);
        }
        WritablesToStringFunction.append(step, joined, delimiter, quote);
        needsSeparator = true;
    }
    return joined.toString();
}
 
Example 7
Source Project: deeplearning4j   Source File: JacksonRecordReaderTest.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Checks that records read with {@code next()} equal those read with {@code nextRecord()}
 * after a reset, and that {@code loadFromMetaData} returns the same records again.
 */
@Test
public void testAppendingLabelsMetaData() throws Exception {
    ClassPathResource jsonResources = new ClassPathResource("datavec-api/json/");
    File workDir = testDir.newFolder();
    jsonResources.copyDirectory(workDir);
    String pathPattern = new File(workDir, "json_test_%d.txt").getAbsolutePath();

    InputSplit split = new NumberedFileInputSplit(pathPattern, 0, 2);

    //Insert at the end:
    RecordReader reader = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()),
                    false, -1, new LabelGen());
    reader.initialize(split);

    // First pass: plain next()
    List<List<Writable>> firstPass = new ArrayList<>();
    while (reader.hasNext()) {
        firstPass.add(reader.next());
    }
    assertEquals(3, firstPass.size());

    reader.reset();

    // Second pass: nextRecord(), keeping values, records and metadata
    List<List<Writable>> secondPass = new ArrayList<>();
    List<Record> records = new ArrayList<>();
    List<RecordMetaData> metaData = new ArrayList<>();
    while (reader.hasNext()) {
        Record current = reader.nextRecord();
        secondPass.add(current.getRecord());
        records.add(current);
        metaData.add(current.getMetaData());
    }

    assertEquals(firstPass, secondPass);

    List<Record> reloaded = reader.loadFromMetaData(metaData);
    assertEquals(records, reloaded);
}
 
Example 8
Source Project: deeplearning4j   Source File: BaseColumnFilter.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Decides whether a whole sequence should be removed.
 *
 * @param sequence the sequence, one inner list per time step
 * @return {@code true} as soon as {@code removeExample} returns true for any time step,
 *         {@code false} if it does for none
 */
@Override
public boolean removeSequence(List<List<Writable>> sequence) {
    boolean anyStepRemoved = false;
    for (List<Writable> timeStep : sequence) {
        if (removeExample(timeStep)) {
            anyStepRemoved = true;
            break;
        }
    }
    return anyStepRemoved;
}
 
Example 9
/**
 * Loads an image from the given writable, applies the optional transform and the ordering
 * permutation, and appends the result to {@code record} as an {@link NDArrayWritable}.
 *
 * <p>Supported inputs are {@code ImageWritable}, {@code BytesWritable}, {@code Text} and
 * {@code NDArrayWritable}. Whether the permutation runs before or after the transform is
 * controlled by {@code isUpdateOrderingBeforeTransform()}.
 *
 * <p>If an {@link IOException} occurs, its stack trace is printed and nothing is added to
 * {@code record}.
 *
 * @param writable   the value to load the image from
 * @param record     the output record the loaded array is appended to
 * @param inputIndex index into the configured input names
 * @param extraArgs  not used by this method
 * @throws IllegalArgumentException if the writable is of an unsupported type
 */
@Override
public void processValidWritable(Writable writable, List<Writable> record, int inputIndex, Object... extraArgs) {
    final String name = imageLoadingStepConfig.getInputNames().get(inputIndex);
    final NativeImageLoader loader = imageLoaders.get(name);

    // The transform is optional: it stays null when no transform processes are configured
    ImageTransformProcess transform = null;
    if (imageLoadingStepConfig.getImageTransformProcesses() != null) {
        transform = imageLoadingStepConfig.getImageTransformProcesses().get(name);
    }

    try {
        final INDArray image;
        if (writable instanceof ImageWritable) {
            image = loader.asMatrix(((ImageWritable) writable).getFrame());
        } else if (writable instanceof BytesWritable) {
            image = loader.asMatrix(((BytesWritable) writable).getContent());
        } else if (writable instanceof Text) {
            image = loader.asMatrix(writable.toString());
        } else if (writable instanceof NDArrayWritable) {
            image = ((NDArrayWritable) writable).get();
        } else {
            throw new IllegalArgumentException("Illegal type to load from " + writable.getClass());
        }

        final INDArray result;
        if (imageLoadingStepConfig.isUpdateOrderingBeforeTransform()) {
            INDArray reordered = permuteImageOrder(image);
            result = applyTransform(transform, loader, reordered);
        } else {
            INDArray transformed = applyTransform(transform, loader, image);
            result = permuteImageOrder(transformed);
        }

        record.add(new NDArrayWritable(result));
    } catch (IOException e) {
        // NOTE(review): the failure is only printed, so the record silently ends up one entry short
        e.printStackTrace();
    }
}
 
Example 10
Source Project: deeplearning4j   Source File: StringAggregatorImpls.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Merges another accumulator into this one by appending its buffer to {@code sb}.
 *
 * @param accu the accumulator to merge; must be an {@code AggregableStringPrepend}
 * @throws UnsupportedOperationException if {@code accu} is of any other type
 */
@Override
public <W extends IAggregableReduceOp<String, Writable>> void combine(W accu) {
    if (accu instanceof AggregableStringPrepend)
        sb.append(((AggregableStringPrepend) accu).getSb());
    else
        // Space after "where" added so the class name is not glued to the preceding word
        throw new UnsupportedOperationException("Tried to combine() incompatible " + accu.getClass().getName()
                        + " operator where " + this.getClass().getName() + " expected");
}
 
Example 11
/**
 * Checks that a record made of an int, an NDArray and another int comes out of the iterator
 * as a single 1x10 feature row holding 0..9.
 */
@Test
public void testRecordReaderDataSetIteratorConcat2() {
    List<Writable> row = new ArrayList<>();
    row.add(new IntWritable(0));
    row.add(new NDArrayWritable(Nd4j.arange(1, 9)));
    row.add(new IntWritable(9));

    RecordReader reader = new CollectionRecordReader(Collections.singletonList(row));
    DataSetIterator iterator = new RecordReaderDataSetIterator(reader, 1);
    DataSet first = iterator.next();

    float[] expectedValues = new float[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    INDArray expectedFeatures = Nd4j.create(expectedValues, new int[]{1,10});

    assertEquals(expectedFeatures, first.getFeatures());
}
 
Example 12
Source Project: DataVec   Source File: FloatWritableOp.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Merges another accumulator into this one by combining the wrapped operations.
 *
 * @param accu the accumulator to merge; must be a {@code FloatWritableOp}
 * @throws UnsupportedOperationException if {@code accu} is of any other type
 */
@Override
public <W extends IAggregableReduceOp<Writable, T>> void combine(W accu) {
    if (!(accu instanceof FloatWritableOp)) {
        String message = "Tried to combine() incompatible " + accu.getClass().getName()
                        + " operator where " + this.getClass().getName() + " expected";
        throw new UnsupportedOperationException(message);
    }
    operation.combine(((FloatWritableOp) accu).getOperation());
}
 
Example 13
Source Project: deeplearning4j   Source File: CoordinatesReduction.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Creates the reduce op with an initially empty list of per-coordinate operations.
 *
 * @param n         stored as {@code nOps}
 * @param initialOp supplier stored as {@code initialOpValue}
 * @param delim     stored as {@code delimiter}
 */
public CoordinateAggregableReduceOp(int n, Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOp,
                String delim) {
    nOps = n;
    delimiter = delim;
    initialOpValue = initialOp;
    perCoordinateOps = new ArrayList<>();
}
 
Example 14
Source Project: deeplearning4j   Source File: LongWritableOp.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Merges another accumulator into this one by combining the wrapped operations.
 *
 * @param accu the accumulator to merge; must be a {@code LongWritableOp}
 * @throws UnsupportedOperationException if {@code accu} is of any other type
 */
@Override
public <W extends IAggregableReduceOp<Writable, T>> void combine(W accu) {
    if (!(accu instanceof LongWritableOp)) {
        String message = "Tried to combine() incompatible " + accu.getClass().getName()
                        + " operator where " + this.getClass().getName() + " expected";
        throw new UnsupportedOperationException(message);
    }
    operation.combine(((LongWritableOp) accu).getOperation());
}
 
Example 15
/**
 * Pairs a record with the values of its key columns.
 *
 * @param writables the full record
 * @return a pair of (values at {@code keyColumnIdxs}, the unchanged record)
 */
@Override
public Pair<List<Writable>, List<Writable>> apply(List<Writable> writables) {
    final int keyCount = keyColumnIdxs.length;
    final List<Writable> keys = new ArrayList<>(keyCount);
    for (int k = 0; k < keyCount; k++) {
        keys.add(writables.get(keyColumnIdxs[k]));
    }
    return Pair.of(keys, writables);
}
 
Example 16
Source Project: deeplearning4j   Source File: RecordReaderFunction.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Reads one record from the given (path, stream) pair.
 *
 * <p>The stream is handed to the record reader as a {@link DataInputStream}. A stream that is
 * not already one is wrapped instead of being cast, so any {@link InputStream} is accepted.
 * The stream is closed when reading finishes.
 *
 * @param value pair of the URI string and the stream to read from
 * @return the record produced by {@code recordReader.record(uri, stream)}
 * @throws IllegalStateException if reading fails; the original {@link IOException} is the cause
 */
@Override
public List<Writable> apply(Pair<String, InputStream> value) {
    URI uri = URI.create(value.getFirst());
    InputStream ds = value.getRight();
    // Wrap only when needed; a null stream is passed through unchanged as before
    DataInputStream data = (ds == null || ds instanceof DataInputStream) ? (DataInputStream) ds
                    : new DataInputStream(ds);
    try (DataInputStream dis = data) {
        return recordReader.record(uri, dis);
    } catch (IOException e) {
        throw new IllegalStateException("Something went wrong reading file", e);
    }

}
 
Example 17
Source Project: DataVec   Source File: SequenceBatchCSVRecord.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Builds a sequence batch from a list of writable time series.
 *
 * <p>Each element of {@code input} is converted with {@code BatchCSVRecord.fromWritables}
 * and added to the result as a single-element list.
 *
 * @param input the time series to convert
 * @return the populated sequence batch
 */
public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) {
    SequenceBatchCSVRecord batch = new SequenceBatchCSVRecord();
    for (List<List<Writable>> series : input) {
        batch.add(Arrays.asList(BatchCSVRecord.fromWritables(series)));
    }
    return batch;
}
 
Example 18
Source Project: deeplearning4j   Source File: CSVRecordReaderTest.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Checks that every record read from the tab-separated resource has exactly two values
 * when the reader is configured with a tab delimiter.
 */
@Test
public void testTabsAsSplit1() throws Exception {
    CSVRecordReader tabReader = new CSVRecordReader(0, '\t');
    FileSplit split = new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile());
    tabReader.initialize(split);

    while (tabReader.hasNext()) {
        List<Writable> row = new ArrayList<>(tabReader.next());
        assertEquals(2, row.size());
    }
}
 
Example 19
Source Project: deeplearning4j   Source File: TypeConversion.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Converts an arbitrary object to an int.
 *
 * <p>A {@link Writable} is passed to the {@code Writable} overload; anything else is
 * converted through its {@code toString()} value.
 *
 * @param o the value to convert
 * @return the int value
 */
public int convertInt(Object o) {
    if (!(o instanceof Writable)) {
        return convertInt(o.toString());
    }
    return convertInt((Writable) o);
}
 
Example 20
Source Project: DataVec   Source File: AggregatorImpls.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Returns the aggregated value.
 *
 * @return {@code override} when it is set, otherwise {@code elem} wrapped by
 *         {@code UnsafeWritableInjector.inject}
 */
@Override
public Writable get() {
    if (override != null) {
        return override;
    }
    return UnsafeWritableInjector.inject(elem);
}
 
Example 21
/**
 * Reads iris.dat through the variable sliding-window reader and checks the window sizes at
 * selected positions, the content of the first full window against a plain CSV reader, and
 * the total number of sequences.
 */
@Test
public void testCSVVariableSlidingWindowRecordReader() throws Exception {
    final int maxLinesPerSequence = 3;

    SequenceRecordReader windowReader = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
    windowReader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    CSVRecordReader plainReader = new CSVRecordReader();
    plainReader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    int seen = 0;
    for (; windowReader.hasNext(); seen++) {
        List<List<Writable>> window = windowReader.sequenceRecord();

        if (seen == maxLinesPerSequence - 1) {
            // First full window: the most recent line comes first
            LinkedList<List<Writable>> expected = new LinkedList<>();
            for (int line = 0; line < maxLinesPerSequence; line++) {
                expected.addFirst(plainReader.next());
            }
            assertEquals(expected, window);
        }
        if (seen == maxLinesPerSequence) {
            assertEquals(maxLinesPerSequence, window.size());
        }
        if (seen == 0) { // first seq should be length 1
            assertEquals(1, window.size());
        }
        // NOTE(review): with a total of 152 sequences the index never exceeds 151, so this
        // check looks unreachable - confirm whether "== 151" was intended.
        if (seen > 151) { // last seq should be length 1
            assertEquals(1, window.size());
        }
    }

    assertEquals(152, seen);
}
 
Example 22
/**
 * Turns a record into a sequence record in which every value becomes its own
 * single-element time step. The metadata is carried over unchanged.
 *
 * @param r the record to convert
 * @return the resulting sequence record
 */
protected SequenceRecord convert(Record r){
    List<Writable> values = r.getRecord();
    List<List<Writable>> steps = new ArrayList<>();
    for (int idx = 0, n = values.size(); idx < n; idx++) {
        steps.add(Collections.singletonList(values.get(idx)));
    }
    return new org.datavec.api.records.impl.SequenceRecord(steps, r.getMetaData());
}
 
Example 23
Source Project: deeplearning4j   Source File: ExtractKeysFunction.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Pairs a record with the values of its key columns.
 *
 * <p>A single key column yields a singleton list; several key columns yield a new list
 * holding the values in the order of {@code columnIndexes}.
 *
 * @param writables the full record
 * @return a pair of (key values, the unchanged record)
 */
@Override
public Pair<List<Writable>, List<Writable>> apply(List<Writable> writables) {
    final int keyCount = columnIndexes.length;
    final List<Writable> keys;
    if (keyCount != 1) {
        keys = new ArrayList<>(keyCount);
        for (int k = 0; k < keyCount; k++) {
            keys.add(writables.get(columnIndexes[k]));
        }
    } else {
        keys = Collections.singletonList(writables.get(columnIndexes[0]));
    }
    return Pair.of(keys, writables);
}
 
Example 24
/**
 * Writes a two-line CSV file and checks that the line sequence reader returns each line as a
 * sequence of single-value time steps, repeatably across three resets.
 */
@Test
public void test() throws Exception {
    File folder = testDir.newFolder();
    File csvFile = new File(folder, "temp.csv");
    String content = "a,b,c\n1,2,3,4";
    FileUtils.writeStringToFile(csvFile, content, StandardCharsets.UTF_8);

    SequenceRecordReader rr = new CSVLineSequenceRecordReader();
    rr.initialize(new FileSplit(csvFile));

    // Expected sequence for the first line
    List<List<Writable>> firstLine = Arrays.asList(
            Collections.<Writable>singletonList(new Text("a")),
            Collections.<Writable>singletonList(new Text("b")),
            Collections.<Writable>singletonList(new Text("c")));

    // Expected sequence for the second line
    List<List<Writable>> secondLine = Arrays.asList(
            Collections.<Writable>singletonList(new Text("1")),
            Collections.<Writable>singletonList(new Text("2")),
            Collections.<Writable>singletonList(new Text("3")),
            Collections.<Writable>singletonList(new Text("4")));

    for (int pass = 0; pass < 3; pass++) {
        int seen = 0;
        while (rr.hasNext()) {
            List<List<Writable>> actual = rr.sequenceRecord();
            List<List<Writable>> expected = (seen == 0) ? firstLine : secondLine;
            seen++;
            assertEquals(expected, actual);
        }

        assertEquals(2, seen);

        rr.reset();
    }
}
 
Example 25
Source Project: DataVec   Source File: SequenceToRowsAdapter.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Converts a sequence into rows, one per time step.
 *
 * <p>Every row starts with a random UUID shared by the whole sequence and the index of the
 * time step, followed by the step's values converted according to the schema column types.
 * Only Double, Integer, Long and Float columns are supported.
 *
 * @param sequence the sequence, one inner list per time step
 * @return the rows; an empty list for an empty sequence
 * @throws IllegalStateException if a column has any other type
 */
@Override
public Iterable<Row> call(List<List<Writable>> sequence) throws Exception {
    if (sequence.size() == 0) {
        return Collections.emptyList();
    }

    final String sequenceId = UUID.randomUUID().toString();
    final List<Row> rows = new ArrayList<>(sequence.size());

    int stepIndex = 0;
    for (List<Writable> step : sequence) {
        // Two leading slots: sequence id and step index
        final Object[] cells = new Object[step.size() + 2];
        cells[0] = sequenceId;
        cells[1] = stepIndex;
        stepIndex++;

        for (int col = 0; col < step.size(); col++) {
            final int target = col + 2;
            switch (schema.getColumnTypes().get(col)) {
                case Double:
                    cells[target] = step.get(col).toDouble();
                    break;
                case Integer:
                    cells[target] = step.get(col).toInt();
                    break;
                case Long:
                    cells[target] = step.get(col).toLong();
                    break;
                case Float:
                    cells[target] = step.get(col).toFloat();
                    break;
                default:
                    throw new IllegalStateException(
                                    "This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
            }
        }

        rows.add(new GenericRowWithSchema(cells, structType));
    }

    return rows;
}
 
Example 26
/**
 * Replaces an empty value with an integer default.
 *
 * @param writable the value to inspect
 * @return a new {@link IntWritable} holding {@code value} when the writable's string form is
 *         {@code null} or empty, otherwise the writable itself
 */
@Override
public Writable map(Writable writable) {
    final String text = writable.toString();
    if (text != null && !text.isEmpty()) {
        return writable;
    }
    return new IntWritable(value);
}
 
Example 27
/**
 * Parses the lines of iris.dat through {@code LineRecordReaderFunction} on Spark and checks
 * that the number of records matches a local CSV read and that every Spark record occurs in
 * the locally read set.
 */
@Test
public void testLineRecordReader() throws Exception {
    File irisFile = new ClassPathResource("iris.dat").getFile();
    List<String> rawLines = FileUtils.readLines(irisFile);

    // Spark side: parse each line with the record reader function
    JavaSparkContext context = getContext();
    JavaRDD<String> rawRdd = context.parallelize(rawLines);
    CSVRecordReader sparkReader = new CSVRecordReader(0, ',');
    JavaRDD<List<Writable>> parsedRdd = rawRdd.map(new LineRecordReaderFunction(sparkReader));
    List<List<Writable>> parsed = parsedRdd.collect();

    // Local side: read the same file directly
    CSVRecordReader localReader = new CSVRecordReader(0, ',');
    localReader.initialize(new FileSplit(irisFile));
    Set<List<Writable>> expected = new HashSet<>();
    int expectedCount = 0;
    for (; localReader.hasNext(); expectedCount++) {
        expected.add(localReader.next());
    }

    assertEquals(expectedCount, parsed.size());

    for (List<Writable> record : parsed) {
        assertTrue(expected.contains(record));
    }
}
 
Example 28
Source Project: scava   Source File: VasttextTextVectorizer.java    License: Eclipse Public License 2.0 5 votes vote down vote up
/**
 * Splits the string form of a writable into tokens.
 *
 * @param writable the value whose {@code toString()} text is tokenized
 * @return all tokens produced by a tokenizer from {@code tokenizerFactory}, in order
 */
protected List<String> tokensFromRecord(Writable writable)
{
    Tokenizer source = tokenizerFactory.create(writable.toString());
    List<String> collected = new ArrayList<>();
    for (; source.hasMoreTokens(); ) {
        collected.add(source.nextToken());
    }
    return collected;
}
 
Example 29
@Test
public void testNDArrayToWritablesScalars() throws Exception {
    INDArray arr = Nd4j.arange(5);
    List<Writable> expected = new ArrayList<>();
    for (int i = 0; i < 5; i++)
        expected.add(new DoubleWritable(i));
    List<Writable> actual = new NDArrayToWritablesFunction().call(arr);
    assertEquals(expected, actual);
}
 
Example 30
Source Project: deeplearning4j   Source File: JacksonRecordReader.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Reads the next record together with URI metadata.
 *
 * <p>The URI is taken from {@code uris[cursor]} before {@code next()} is called.
 *
 * @return the record and its metadata
 */
@Override
public Record nextRecord() {
    final URI source = uris[cursor];
    final List<Writable> values = next();
    return new org.datavec.api.records.impl.Record(values,
                    new RecordMetaDataURI(source, JacksonRecordReader.class));
}