org.datavec.api.writable.NDArrayWritable Java Examples

The following examples show how to use org.datavec.api.writable.NDArrayWritable. Each example is taken from an open-source project; the source file, project, and license are listed above it.
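Before the examples, a minimal sketch of the round trip that most of the snippets below rely on: wrapping an INDArray in an NDArrayWritable and reading it back with get(). The constructor and get() appear throughout the examples; the standalone class and main method here are added purely for illustration.

import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NDArrayWritableSketch {
    public static void main(String[] args) {
        // Wrap a [1,3] row vector in an NDArrayWritable, as many of the examples below do
        INDArray row = Nd4j.create(new float[] {1, 2, 3}, new long[] {1, 3});
        Writable w = new NDArrayWritable(row);

        // Unwrap the array again; get() returns the wrapped INDArray (which may be null)
        INDArray unwrapped = ((NDArrayWritable) w).get();
        System.out.println(unwrapped);
    }
}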
Example #1
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testRecordReaderDataSetIteratorDisjointFeatures() {

    //Idea: input vector is like [f,f,f,f,l,l,f,f] or similar - i.e., the label writables aren't at the start or end

    List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1),
                    new NDArrayWritable(Nd4j.create(new float[] {2, 3, 4}, new long[]{1,3})), new DoubleWritable(5),
                    new NDArrayWritable(Nd4j.create(new float[] {6, 7, 8}, new long[]{1,3})));

    INDArray expF = Nd4j.create(new float[] {1, 6, 7, 8}, new long[]{1,4});
    INDArray expL = Nd4j.create(new float[] {2, 3, 4, 5}, new long[]{1,4});

    RecordReader rr = new CollectionRecordReader(Collections.singletonList(l));

    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 1, 2, true);

    DataSet ds = iter.next();
    assertEquals(expF, ds.getFeatures());
    assertEquals(expL, ds.getLabels());
}
 
Example #2
Source File: InferenceExecutionerStepRunner.java    From konduit-serving with Apache License 2.0
private Record[] toNDArray(Record[] records) {
    if (records[0].getRecord().size() > 1 && !recordIsAllNumeric(records[0])) {
        throw new IllegalArgumentException("Invalid record type passed in. This pipeline only accepts records with singular ndarray records representing 1 input array per name for input graphs or purely numeric arrays that can be converted to a matrix");
    } else if (allNdArray(records)) {
        return records;
    } else {
        INDArray arr = Nd4j.create(records.length, records[0].getRecord().size());
        for (int i = 0; i < arr.rows(); i++) {
            for (int j = 0; j < arr.columns(); j++) {
                arr.putScalar(i, j, records[i].getRecord().get(j).toDouble());
            }
        }

        return new Record[]{
                new org.datavec.api.records.impl.Record(
                        Collections.singletonList(new NDArrayWritable(arr))
                        , null
                )};
    }
}
 
Example #3
Source File: VertxBufferNumpyInputAdapter.java    From konduit-serving with Apache License 2.0
/**
 * Convert Buffer input to NDArray writable. Note that contextData is unused in this implementation of InputAdapter.
 */
@Override
public NDArrayWritable convert(Buffer input, ConverterArgs parameters, Map<String, Object> contextData) {
    Preconditions.checkState(input.length() > 0, "Buffer appears to be empty!");
    INDArray fromNpyPointer = Nd4j.getNDArrayFactory().createFromNpyPointer(
            new BytePointer(input.getByteBuf().nioBuffer())
    );
    if (permuteRequired(parameters)) {
        fromNpyPointer = ImagePermuter.permuteOrder(
                fromNpyPointer,
                parameters.getImageProcessingInitialLayout(),
                parameters.getImageProcessingRequiredLayout()
        );
    }

    return new NDArrayWritable(fromNpyPointer);
}
 
Example #4
Source File: TestWritablesToNDArrayFunction.java    From deeplearning4j with Apache License 2.0
@Test
public void testWritablesToNDArrayMixed() throws Exception {
    Nd4j.setDataType(DataType.FLOAT);
    List<Writable> l = new ArrayList<>();
    l.add(new IntWritable(0));
    l.add(new IntWritable(1));
    INDArray arr = Nd4j.arange(2, 5).reshape(1, 3);
    l.add(new NDArrayWritable(arr));
    l.add(new IntWritable(5));
    arr = Nd4j.arange(6, 9).reshape(1, 3);
    l.add(new NDArrayWritable(arr));
    l.add(new IntWritable(9));

    INDArray expected = Nd4j.arange(10).castTo(DataType.FLOAT).reshape(1, 10);
    assertEquals(expected, new WritablesToNDArrayFunction().apply(l));
}
 
Example #5
Source File: NDArrayAnalysisCounter.java    From deeplearning4j with Apache License 2.0
@Override
public NDArrayAnalysisCounter add(Writable writable) {
    NDArrayWritable n = (NDArrayWritable) writable;
    INDArray arr = n.get();
    countTotal++;
    if (arr == null) {
        countNull++;
    } else {
        minLength = Math.min(minLength, arr.length());
        maxLength = Math.max(maxLength, arr.length());

        int r = arr.rank();
        if (countsByRank.containsKey(r)) {
            countsByRank.put(r, countsByRank.get(r) + 1);
        } else {
            countsByRank.put(r, 1L);
        }

        totalNDArrayValues += arr.length();
        minValue = Math.min(minValue, arr.minNumber().doubleValue());
        maxValue = Math.max(maxValue, arr.maxNumber().doubleValue());
    }

    return this;
}
 
Example #6
Source File: TfidfRecordReader.java    From deeplearning4j with Apache License 2.0
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
    List<Record> out = new ArrayList<>();

    for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) {
        INDArray transform = tfidfVectorizer.transform(fileContents);

        org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
                        new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
                        new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));

        if (appendLabel)
            record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
        out.add(record);
    }

    return out;
}
 
Example #7
Source File: TestImageRecordReader.java    From deeplearning4j with Apache License 2.0
private static List<Writable> testMultiLabel(String filename){
    switch(filename){
        case "0.jpg":
            return Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)),
                    new NDArrayWritable(Nd4j.create(new double[]{1,0,0}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(0.0));
        case "1.png":
            return Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)),
                    new NDArrayWritable(Nd4j.create(new double[]{0,1,0}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(1.0));
        case "2.jpg":
            return Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)),
                    new NDArrayWritable(Nd4j.create(new double[]{0,0,1}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(2.0));
        case "A.jpg":
            return Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)),
                    new NDArrayWritable(Nd4j.create(new double[]{1,0,0}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(3.0));
        case "B.png":
            return Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)),
                    new NDArrayWritable(Nd4j.create(new double[]{0,1,0}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(4.0));
        case "C.jpg":
            return Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)),
                    new NDArrayWritable(Nd4j.create(new double[]{0,0,1}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5.0));
        default:
            throw new RuntimeException(filename);
    }
}
 
Example #8
Source File: RecordReaderMultiDataSetIterator.java    From deeplearning4j with Apache License 2.0
private int countLength(List<Writable> list, int from, int to) {
    int length = 0;
    for (int i = from; i <= to; i++) {
        Writable w = list.get(i);
        if (w instanceof NDArrayWritable) {
            INDArray a = ((NDArrayWritable) w).get();
            if (!a.isRowVectorOrScalar()) {
                throw new UnsupportedOperationException("Multiple writables present but NDArrayWritable is "
                                + "not a row vector. Can only concat row vectors with other writables. Shape: "
                                + Arrays.toString(a.shape()));
            }
            length += a.length();
        } else {
            //Assume all others are single value
            length++;
        }
    }

    return length;
}
 
Example #9
Source File: NDArrayHistogramCounter.java    From deeplearning4j with Apache License 2.0
@Override
public HistogramCounter add(Writable w) {
    INDArray arr = ((NDArrayWritable) w).get();
    if (arr == null) {
        return this;
    }

    long length = arr.length();
    DoubleWritable dw = new DoubleWritable();
    for (int i = 0; i < length; i++) {
        dw.set(arr.getDouble(i));
        underlying.add(dw);
    }

    return this;
}
 
Example #10
Source File: VasttextTextFileReader.java    From scava with Eclipse Public License 2.0
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
    List<Record> out = new ArrayList<>();

    for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) {
        INDArray transformed = vasttextTextVectorizer.transform(fileContents);

        org.datavec.api.records.impl.Record transformedRecord = new org.datavec.api.records.impl.Record(
                        new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transformed))),
                        new RecordMetaDataURI(fileContents.getMetaData().getURI(), VasttextTextFileReader.class));

        if (labelled)
            transformedRecord.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
        out.add(transformedRecord);
    }

    return out;
}
 
Example #11
Source File: VasttextExtraFileReader.java    From scava with Eclipse Public License 2.0
public Record processNextRecord() {
    //We need to split and find the label(s)
    String[] line = super.next().get(0).toString().split(" ");
    double[] extraFeatures = new double[numericFeaturesSize];
    if (line.length != numericFeaturesSize)
        throw new UnsupportedOperationException("Features defined and features found do not match. Found: "
                + line.length + " Declared: " + numericFeaturesSize);
    for (int i = 0; i < numericFeaturesSize; i++) {
        extraFeatures[i] = Double.parseDouble(line[i]);
    }
    INDArray transformed = Nd4j.create(extraFeatures, new int[] {extraFeatures.length, 1});

    URI uri = (locations == null || locations.length < 1 ? null : locations[splitIndex]);
    RecordMetaData meta = new RecordMetaDataLine(this.lineIndex - 1, uri, LineRecordReader.class); //-1 as line number has been incremented already...
    return new org.datavec.api.records.impl.Record(
            new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transformed))), meta);
}
 
Example #12
Source File: NDArrayHistogramCounter.java    From DataVec with Apache License 2.0
@Override
public HistogramCounter add(Writable w) {
    INDArray arr = ((NDArrayWritable) w).get();
    if (arr == null) {
        return this;
    }

    long length = arr.length();
    DoubleWritable dw = new DoubleWritable();
    for (int i = 0; i < length; i++) {
        dw.set(arr.getDouble(i));
        underlying.add(dw);
    }

    return this;
}
 
Example #13
Source File: TfidfRecordReader.java    From DataVec with Apache License 2.0
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
    List<Record> out = new ArrayList<>();

    for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) {
        INDArray transform = tfidfVectorizer.transform(fileContents);

        org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record(
                        new ArrayList<>(Collections.<Writable>singletonList(new NDArrayWritable(transform))),
                        new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class));

        if (appendLabel)
            record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1));
        out.add(record);
    }

    return out;
}
 
Example #14
Source File: TfidfRecordReaderTest.java    From DataVec with Apache License 2.0
@Test
public void testReader() throws Exception {
    TfidfVectorizer vectorizer = new TfidfVectorizer();
    Configuration conf = new Configuration();
    conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1);
    conf.setBoolean(RecordReader.APPEND_LABEL, true);
    vectorizer.initialize(conf);
    TfidfRecordReader reader = new TfidfRecordReader();
    reader.initialize(conf, new FileSplit(new ClassPathResource("labeled").getFile()));
    int count = 0;
    int[] labelAssertions = new int[3];
    while (reader.hasNext()) {
        Collection<Writable> record = reader.next();
        Iterator<Writable> recordIter = record.iterator();
        NDArrayWritable writable = (NDArrayWritable) recordIter.next();
        labelAssertions[count] = recordIter.next().toInt();
        count++;
    }

    assertArrayEquals(new int[] {0, 1, 2}, labelAssertions);
    assertEquals(3, reader.getLabels().size());
    assertEquals(3, count);
}
 
Example #15
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testRecordReaderDataSetIteratorConcat() {

    //[DoubleWritable, NDArrayWritable([1,3]), DoubleWritable, NDArrayWritable([1,3]), IntWritable, IntWritable] -> the first five writables are concatenated into a [1,9] feature vector automatically; the final IntWritable is the class label (one-hot to [1,3])

    List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1),
                    new NDArrayWritable(Nd4j.create(new double[] {2, 3, 4})), new DoubleWritable(5),
                    new NDArrayWritable(Nd4j.create(new double[] {6, 7, 8})), new IntWritable(9),
                    new IntWritable(1));

    RecordReader rr = new CollectionRecordReader(Collections.singletonList(l));

    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 5, 3);

    DataSet ds = iter.next();
    INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,9});
    INDArray expL = Nd4j.create(new float[] {0, 1, 0}, new int[]{1,3});

    assertEquals(expF, ds.getFeatures());
    assertEquals(expL, ds.getLabels());
}
 
Example #16
Source File: NDArrayToWritablesFunction.java    From deeplearning4j with Apache License 2.0
@Override
public List<Writable> apply(INDArray arr) {
    if (arr.rows() != 1)
        throw new UnsupportedOperationException("Only NDArray row vectors can be converted to list"
                                            + " of Writables (found " + arr.rows() + " rows)");
    List<Writable> record = new ArrayList<>();
    if (useNdarrayWritable) {
        record.add(new NDArrayWritable(arr));
    } else {
        for (int i = 0; i < arr.columns(); i++)
            record.add(new DoubleWritable(arr.getDouble(i)));
    }
    return record;
}
 
Example #17
Source File: TokenizerBagOfWordsTermSequenceIndexTransform.java    From deeplearning4j with Apache License 2.0
@Override
public List<Writable> map(List<Writable> writables) {
    int idx = inputSchema.getIndexOfColumn(columnName);
    Text text = (Text) writables.get(idx);
    List<Writable> ret = new ArrayList<>(writables);
    ret.set(idx, new NDArrayWritable(convert(text.toString())));
    return ret;
}
 
Example #18
Source File: LibSvmRecordWriterTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 0);
    arr3.putScalar(1, 1);
    arr3.putScalar(2, 0);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
            new NDArrayWritable(arr2),
            new IntWritable(2),
            new DoubleWritable(3),
            new NDArrayWritable(arr3),
            new DoubleWritable(1));
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #19
Source File: NDArrayDistanceTransform.java    From deeplearning4j with Apache License 2.0
@Override
public List<Writable> map(List<Writable> writables) {
    int idxFirst = inputSchema.getIndexOfColumn(firstCol);
    int idxSecond = inputSchema.getIndexOfColumn(secondCol);

    INDArray arr1 = ((NDArrayWritable) writables.get(idxFirst)).get();
    INDArray arr2 = ((NDArrayWritable) writables.get(idxSecond)).get();

    double d;
    switch (distance) {
        case COSINE:
            d = Transforms.cosineSim(arr1, arr2);
            break;
        case EUCLIDEAN:
            d = Transforms.euclideanDistance(arr1, arr2);
            break;
        case MANHATTAN:
            d = Transforms.manhattanDistance(arr1, arr2);
            break;
        default:
            throw new UnsupportedOperationException("Unknown or not supported distance metric: " + distance);
    }

    List<Writable> out = new ArrayList<>(writables.size() + 1);
    out.addAll(writables);
    out.add(new DoubleWritable(d));

    return out;
}
 
Example #20
Source File: TestNDArrayToWritablesFunction.java    From deeplearning4j with Apache License 2.0
@Test
public void testNDArrayToWritablesArray() throws Exception {
    INDArray arr = Nd4j.arange(5);
    List<Writable> expected = Arrays.asList((Writable) new NDArrayWritable(arr));
    List<Writable> actual = new NDArrayToWritablesFunction(true).apply(arr);
    assertEquals(expected, actual);
}
 
Example #21
Source File: LibSvmRecordWriterTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testNDArrayWritablesMultilabel() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 0);
    arr3.putScalar(1, 1);
    arr3.putScalar(2, 0);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
            new NDArrayWritable(arr2),
            new IntWritable(2),
            new DoubleWritable(3),
            new NDArrayWritable(arr3),
            new DoubleWritable(1));
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #22
Source File: NDArrayMetaData.java    From deeplearning4j with Apache License 2.0
@Override
public boolean isValid(Object input) {
    if (input == null) {
        return false;
    } else if (input instanceof Writable) {
        return isValid((Writable) input);
    } else if (input instanceof INDArray) {
        return isValid(new NDArrayWritable((INDArray) input));
    } else {
        throw new UnsupportedOperationException("Unknown object type: " + input.getClass());
    }
}
 
Example #23
Source File: SVMLightRecordWriterTest.java    From DataVec with Apache License 2.0
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 0);
    arr3.putScalar(1, 1);
    arr3.putScalar(2, 0);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
            new NDArrayWritable(arr2),
            new IntWritable(2),
            new DoubleWritable(3),
            new NDArrayWritable(arr3),
            new DoubleWritable(1));
    File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";

    try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
        configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #24
Source File: SVMLightRecordWriterTest.java    From DataVec with Apache License 2.0
@Test
public void testNDArrayWritablesMultilabel() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 0);
    arr3.putScalar(1, 1);
    arr3.putScalar(2, 0);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
            new NDArrayWritable(arr2),
            new IntWritable(2),
            new DoubleWritable(3),
            new NDArrayWritable(arr3),
            new DoubleWritable(1));
    File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";

    try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
        configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #25
Source File: SVMLightRecordWriterTest.java    From DataVec with Apache License 2.0
@Test
public void testNDArrayWritables() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 13);
    arr3.putScalar(1, 14);
    arr3.putScalar(2, 15);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
                                        new NDArrayWritable(arr2),
                                        new IntWritable(2),
                                        new DoubleWritable(3),
                                        new NDArrayWritable(arr3),
                                        new IntWritable(4));
    File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";

    try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #26
Source File: RecordReaderDataSetiteratorTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testSeqRRDSIArrayWritableOneReader() {

    List<List<Writable>> sequence1 = new ArrayList<>();
    sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})),
                    new IntWritable(0)));
    sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})),
                    new IntWritable(1)));
    List<List<Writable>> sequence2 = new ArrayList<>();
    sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})),
                    new IntWritable(2)));
    sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})),
                    new IntWritable(3)));


    SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2));

    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, 4, 1, false);

    DataSet ds = iter.next();

    INDArray expFeatures = Nd4j.create(2, 3, 2); //2 examples, 3 values per time step, 2 time steps
    expFeatures.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] {{1, 4}, {2, 5}, {3, 6}}));
    expFeatures.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] {{7, 10}, {8, 11}, {9, 12}}));

    INDArray expLabels = Nd4j.create(2, 4, 2);
    expLabels.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] {{1, 0}, {0, 1}, {0, 0}, {0, 0}}));
    expLabels.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] {{0, 0}, {0, 0}, {1, 0}, {0, 1}}));

    assertEquals(expFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
}
 
Example #27
Source File: NDArrayRecordBatch.java    From DataVec with Apache License 2.0
@Override
public List<Writable> get(int index) {
    Preconditions.checkArgument(index >= 0 && index < size, "Invalid index: " + index + ", size = " + size);
    List<Writable> out = new ArrayList<>((int) size);
    for (INDArray orig : arrays) {
        INDArray view = getExample(index, orig);
        out.add(new NDArrayWritable(view));
    }
    return out;
}
 
Example #28
Source File: SchemaTypeUtils.java    From konduit-serving with Apache License 2.0
/**
 * Convert a batch of {@link INDArray}s to {@link Record}s,
 * each comprising a single {@link NDArrayWritable}.
 *
 * @param input the input
 * @return the equivalent output records
 */
public static Record[] toRecords(INDArray[] input) {
    Record[] ret = new Record[input.length];
    for (int i = 0; i < ret.length; i++) {
        ret[i] = new org.datavec.api.records.impl.Record(
                Arrays.asList(new NDArrayWritable(input[i]))
                , null);
    }

    return ret;
}
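A brief usage sketch for the helper above (the array contents are chosen here purely for illustration): a batch of two arrays becomes two Records, each holding a single NDArrayWritable.

INDArray[] batch = new INDArray[] {Nd4j.ones(1, 3), Nd4j.zeros(1, 3)};
Record[] records = SchemaTypeUtils.toRecords(batch);
// records[0].getRecord().get(0) is an NDArrayWritable wrapping the ones array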
 
Example #29
Source File: SVMLightRecordWriterTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 0);
    arr3.putScalar(1, 1);
    arr3.putScalar(2, 0);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
            new NDArrayWritable(arr2),
            new IntWritable(2),
            new DoubleWritable(3),
            new NDArrayWritable(arr3),
            new DoubleWritable(1));
    File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";

    try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
        configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}
 
Example #30
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
    INDArray arr2 = Nd4j.zeros(2);
    arr2.putScalar(0, 11);
    arr2.putScalar(1, 12);
    INDArray arr3 = Nd4j.zeros(3);
    arr3.putScalar(0, 0);
    arr3.putScalar(1, 1);
    arr3.putScalar(2, 0);
    List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
            new NDArrayWritable(arr2),
            new IntWritable(2),
            new DoubleWritable(3),
            new NDArrayWritable(arr3),
            new DoubleWritable(1));
    File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
    tempFile.setWritable(true);
    tempFile.deleteOnExit();
    if (tempFile.exists())
        tempFile.delete();

    String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";

    try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
        Configuration configWriter = new Configuration();
        configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
        configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
        configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
        FileSplit outputSplit = new FileSplit(tempFile);
        writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
        writer.write(record);
    }

    String lineNew = FileUtils.readFileToString(tempFile).trim();
    assertEquals(lineOriginal, lineNew);
}