org.nd4j.linalg.io.ClassPathResource Java Examples

The following examples show how to use org.nd4j.linalg.io.ClassPathResource. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: TransformProcessRecordReaderTests.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Reads iris.dat through a TransformProcessRecordReader that drops column "0".
 *
 * <p>Checks that every record keeps 4 values and that 150 records are produced. It then checks
 * that, after a reset, a single batch read of 150 returns the same records.
 */
@Test
public void simpleTransformTest() throws Exception {
    Schema inputSchema = new Schema.Builder().addColumnsDouble("%d", 0, 4).build();
    TransformProcess dropFirstColumn = new TransformProcess.Builder(inputSchema)
            .removeColumns("0")
            .build();

    CSVRecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));
    TransformProcessRecordReader transformed = new TransformProcessRecordReader(csvReader, dropFirstColumn);

    List<List<Writable>> sequential = new ArrayList<>();
    while (transformed.hasNext()) {
        List<Writable> record = transformed.next();
        assertEquals(4, record.size());
        sequential.add(record);
    }
    assertEquals(150, sequential.size());

    // A batch read after reset must reproduce the record-by-record pass.
    assertTrue(transformed.resetSupported());
    transformed.reset();
    assertEquals(sequential, transformed.next(150));
}
 
Example #2
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the ae_00 frozen TensorFlow graph and checks each function's naming.
 *
 * <p>Every function's own name must also be a variable name in the graph and must equal the
 * name of the function's first output variable.
 */
@Test
public void testImportMapping1() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import.
    Nd4j.create(1);
    // Close the classpath stream once we are done with it; it was previously leaked.
    try (val graphStream = new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);

        val variables = new HashMap<String, SDVariable>();
        for (val var : tg.variables()) {
            variables.put(var.getVarName(), var);
        }

        for (val func : tg.functions()) {
            val ownName = func.getOwnName();
            val outName = func.outputVariables()[0].getVarName();

            assertTrue("Missing ownName: [" + ownName + "]", variables.containsKey(ownName));
            assertEquals(ownName, outName);
        }
    }
}
 
Example #3
Source File: CodecReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Decodes fire_lowres.mp4 with a NativeCodecRecordReader and checks the first time step.
 *
 * <p>The reader is configured ravelled, 80 rows x 46 columns, starting at frame 160 with 500
 * frames requested. The first time step must hold a single ArrayWritable of 80 * 46 * 3 values.
 */
@Ignore
@Test
public void testNativeCodecReader() throws Exception {
    File video = new ClassPathResource("fire_lowres.mp4").getFile();
    SequenceRecordReader nativeReader = new NativeCodecRecordReader();

    Configuration codecConf = new Configuration();
    codecConf.set(CodecRecordReader.RAVEL, "true");
    codecConf.set(CodecRecordReader.START_FRAME, "160");
    codecConf.set(CodecRecordReader.TOTAL_FRAMES, "500");
    codecConf.set(CodecRecordReader.ROWS, "80");
    codecConf.set(CodecRecordReader.COLUMNS, "46");

    // NOTE(review): setConf() is called after initialize(); this order is kept deliberately.
    nativeReader.initialize(new FileSplit(video));
    nativeReader.setConf(codecConf);
    assertTrue(nativeReader.hasNext());

    List<List<Writable>> sequence = nativeReader.sequenceRecord();
    List<Writable> firstStep = sequence.iterator().next();

    // One writable per time step, holding 80 x 46 x 3 values.
    assertEquals(1, firstStep.size());
    ArrayWritable pixels = (ArrayWritable) firstStep.iterator().next();
    assertEquals(80 * 46 * 3, pixels.length());
}
 
Example #4
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the transpose example graph, binds two 3x3 inputs and writes the graph to disk.
 *
 * <p>The inputs are bound to "input" and "input_1". The graph is then written with asFlatFile
 * into the libnd4j test resources directory.
 */
@Test
@Ignore
public void testCrash_119_transpose() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import.
    Nd4j.create(1);

    // Close the classpath stream once we are done with it; it was previously leaked.
    try (val graphStream = new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);

        val input0 = Nd4j.create(new double[]{0.98114507, 0.96400015, 0.58669623, 0.60073098, 0.75425418, 0.44258752, 0.76373084, 0.96593234, 0.34067846}, new int[] {3, 3});
        val input1 = Nd4j.create(new double[]{0.98114507, 0.60073098, 0.76373084, 0.96400015, 0.75425418, 0.96593234, 0.58669623, 0.44258752, 0.34067846}, new int[] {3, 3});

        tg.associateArrayWithVariable(input0, tg.getVariable("input"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/transpose.fb"));
    }
}
 
Example #5
Source File: OnnxGraphMapperTests.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Parses embedding_only.onnx and checks the OnnxGraphMapper's view of the graph.
 *
 * <p>It checks the node count, the variable count (4) and the per-node attribute counts, and
 * that importing yields a SameDiff with exactly one function. The test is skipped (via
 * assumeNotNull) when the first initializer cannot be converted to an INDArray.
 */
@Test
public void testMapper() throws Exception {
    try (val modelStream = new ClassPathResource("onnx_graphs/embedding_only.onnx").getInputStream()) {
        OnnxProto3.GraphProto graph = OnnxProto3.ModelProto.parseFrom(modelStream).getGraph();
        OnnxGraphMapper mapper = new OnnxGraphMapper();

        int protoNodeCount = graph.getNodeList().size();
        assertEquals(protoNodeCount, mapper.getNodeList(graph).size());
        assertEquals(4, mapper.variablesForGraph(graph).size());

        // The first graph input's tensor type is used to decode the first initializer.
        val firstInputTensorType = graph.getInput(0).getType().getTensorType();
        val firstInitializerName = graph.getInitializer(0).getName();
        INDArray initializerArray = mapper.getNDArrayFromTensor(firstInitializerName, firstInputTensorType, graph);
        assumeNotNull(initializerArray);

        for (val node : graph.getNodeList()) {
            assertEquals(node.getAttributeList().size(), mapper.getAttrMap(node).size());
        }

        val imported = mapper.importGraph(graph);
        assertEquals(1, imported.functions().length);
        System.out.println(imported);
    }
}
 
Example #6
Source File: IrisUtils.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Loads lines {@code [from, to)} of the bundled {@code /iris.dat} resource as single-example
 * {@link DataSet}s.
 *
 * <p>Each line is split on commas. The row is passed to {@code addRow} (presumably to fill the
 * 4 feature columns; confirm). The last field is parsed as an integer class index in
 * {@code [0, 3)} and one-hot encoded as the label.
 *
 * @param from index of the first line to load (inclusive)
 * @param to   index one past the last line to load (exclusive)
 * @return one {@link DataSet} per loaded line, in file order; empty when {@code from == to}
 * @throws IllegalArgumentException if {@code from < 0}, {@code to < from}, or {@code to} is past
 *                                  the last line of the resource
 * @throws IOException if the resource cannot be read
 */
public static List<DataSet> loadIris(int from, int to) throws IOException {
    ClassPathResource resource = new ClassPathResource("/iris.dat", IrisUtils.class.getClassLoader());
    List<String> lines;
    // Close the classpath stream as soon as all lines are read; it was previously leaked.
    try (java.io.InputStream in = resource.getInputStream()) {
        @SuppressWarnings("unchecked")
        List<String> read = IOUtils.readLines(in);
        lines = read;
    }

    // A reversed or out-of-range window previously produced wrong row counts or
    // out-of-bounds/zero-label rows instead of a clear error.
    if (from < 0 || to < from || to > lines.size()) {
        throw new IllegalArgumentException("Invalid iris row range [" + from + ", " + to + ") for "
                + lines.size() + " lines");
    }

    int numRows = to - from;
    List<DataSet> list = new ArrayList<>(numRows);
    if (numRows == 0) {
        return list;
    }

    INDArray ret = Nd4j.ones(numRows, 4);
    for (int row = 0; row < numRows; row++) {
        String[] split = lines.get(from + row).split(",");
        addRow(ret, row, split);

        double[] rowOutcome = new double[3];
        rowOutcome[Integer.parseInt(split[split.length - 1])] = 1;
        list.add(new DataSet(ret.getRow(row), Nd4j.create(rowOutcome)));
    }
    return list;
}
 
Example #7
Source File: FileRecordReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Reads iris.dat with FileRecordReader five times, resetting between passes.
 *
 * <p>Every pass must yield exactly one record containing a single writable (presumably the whole
 * file content), after which the reader is exhausted.
 */
@Test
public void testReset() throws Exception {
    FileRecordReader reader = new FileRecordReader();
    reader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));

    final int passes = 5;
    for (int pass = 0; pass < passes; pass++) {
        int records = 0;
        for (; reader.hasNext(); records++) {
            assertEquals(1, reader.next().size());
        }
        assertFalse(reader.hasNext());
        assertEquals(1, records);
        reader.reset();
    }
}
 
Example #8
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the simpleif_0 frozen graph, binds a 2x2 array of -1 to "input_0" and executes it.
 *
 * <p>The final result must be a 2x2 array of 1.
 */
@Test
public void testCondMapping2() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import.
    Nd4j.create(1);
    // Close the classpath stream once we are done with it; it was previously leaked.
    try (val graphStream = new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);

        val input = Nd4j.create(2, 2).assign(-1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
Example #9
Source File: ImageSparkTransformTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Runs three test images (two JPEG, one PNG) through an ImageSparkTransform.
 *
 * <p>The transform uses a seeded scale(10) + crop(5) process. The Base64-decoded batch array
 * must have size 3 along dimension 0.
 */
@Test
public void testBatchImageSparkTransform() throws Exception {
    int seed = 12345;

    File imageA = new ClassPathResource("/testimages/class1/A.jpg").getFile();
    File imageB = new ClassPathResource("/testimages/class1/B.png").getFile();
    File imageC = new ClassPathResource("/testimages/class1/C.jpg").getFile();

    BatchImageRecord batch = new BatchImageRecord();
    for (File image : new File[] {imageA, imageB, imageC}) {
        batch.add(image.toURI());
    }

    ImageTransformProcess process = new ImageTransformProcess.Builder()
            .seed(seed)
            .scaleImageTransform(10)
            .cropImageTransform(5)
            .build();
    Base64NDArrayBody encoded = new ImageSparkTransform(process).toArray(batch);

    INDArray decoded = Nd4jBase64.fromBase64(encoded.getNdarray());
    System.out.println("Base 64ed array " + decoded);
    assertEquals(3, decoded.size(0));
}
 
Example #10
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Passes svmlight/multilabel.txt to executeTest with zero-based indexing enabled on the writer.
 *
 * <p>The writer covers feature columns 0..10 in multilabel mode. The reader expects 11 features
 * and 5 labels in multilabel mode.
 */
@Test
public void testZeroBasedIndexing() throws Exception {
    Configuration writerConf = new Configuration();
    writerConf.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
    writerConf.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
    writerConf.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
    writerConf.setBoolean(LibSvmRecordWriter.MULTILABEL, true);

    Configuration readerConf = new Configuration();
    readerConf.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
    readerConf.setBoolean(LibSvmRecordReader.MULTILABEL, true);
    readerConf.setInt(LibSvmRecordReader.NUM_LABELS, 5);

    executeTest(writerConf, readerConf, new ClassPathResource("svmlight/multilabel.txt").getFile());
}
 
Example #11
Source File: HyperParameterTuningArbiterUiExample.java    From Java-Deep-Learning-Cookbook with MIT License 6 votes vote down vote up
/**
 * Builds a record reader over Churn_Modelling.csv that applies a cleaning transform.
 *
 * <p>The CSV header line is skipped. The transform drops RowNumber, Surname and CustomerId,
 * integer-encodes Gender, and one-hot encodes Geography, then drops the Geography[France]
 * indicator column.
 *
 * @return a TransformProcessRecordReader that applies the transform to each CSV row
 */
public RecordReader dataPreprocess() throws IOException, InterruptedException {
    // NOTE(review): Balance and EstimatedSalary are declared as integer columns; confirm against
    // the CSV whether they actually hold decimal values.
    Schema churnSchema = new Schema.Builder()
            .addColumnsString("RowNumber")
            .addColumnInteger("CustomerId")
            .addColumnString("Surname")
            .addColumnInteger("CreditScore")
            .addColumnCategorical("Geography", Arrays.asList("France", "Spain", "Germany"))
            .addColumnCategorical("Gender", Arrays.asList("Male", "Female"))
            .addColumnsInteger("Age", "Tenure", "Balance", "NumOfProducts", "HasCrCard",
                    "IsActiveMember", "EstimatedSalary", "Exited")
            .build();

    TransformProcess churnTransform = new TransformProcess.Builder(churnSchema)
            .removeColumns("RowNumber", "Surname", "CustomerId")
            .categoricalToInteger("Gender")
            .categoricalToOneHot("Geography")
            .removeColumns("Geography[France]")
            .build();

    // Skip one header line; fields are comma-separated.
    RecordReader csvReader = new CSVRecordReader(1, ',');
    csvReader.initialize(new FileSplit(new ClassPathResource("Churn_Modelling.csv").getFile()));
    return new TransformProcessRecordReader(csvReader, churnTransform);
}
 
Example #12
Source File: BasicLineIteratorExample.java    From Java-Deep-Learning-Cookbook with MIT License 6 votes vote down vote up
/**
 * Counts and then prints the sentences of raw_sentences.txt.
 *
 * <p>First counts the lines with a BasicLineIterator and prints the count. It then resets the
 * iterator, installs the preprocessor provided by SentenceDataPreProcessor, and prints each
 * sentence.
 */
public static void main(String[] args) throws IOException {
    SentenceIterator sentences = new BasicLineIterator(new ClassPathResource("raw_sentences.txt").getFile());

    int count = 0;
    for (; sentences.hasNext(); count++) {
        sentences.nextSentence();
    }
    System.out.println("count = " + count);

    // Second pass: same source, now with the preprocessor applied.
    sentences.reset();
    SentenceDataPreProcessor.setPreprocessor(sentences);
    while (sentences.hasNext()) {
        System.out.println(sentences.nextSentence());
    }
}
 
Example #13
Source File: RegexRecordReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Parses logtestfile0.txt with a RegexLineRecordReader (one line skipped).
 *
 * <p>Each log line is split into timestamp, id, level and message. Exactly three expected
 * records must be read, both on the first pass and again after reset().
 */
@Test
public void testRegexLineRecordReader() throws Exception {
    String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";

    RecordReader reader = new RegexLineRecordReader(regex, 1);
    reader.initialize(new FileSplit(new ClassPathResource("/logtestdata/logtestfile0.txt").getFile()));

    List<Writable> first = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"),
                    new Text("DEBUG"), new Text("First entry message!"));
    List<Writable> second = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"),
                    new Text("INFO"), new Text("Second entry message!"));
    List<Writable> third = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"),
                    new Text("WARN"), new Text("Third entry message!"));
    List<List<Writable>> expected = Arrays.asList(first, second, third);

    // Pass 0 reads from the start; pass 1 verifies that reset() rewinds to the same records.
    for (int pass = 0; pass < 2; pass++) {
        if (pass == 1) {
            reader.reset();
        }
        for (List<Writable> record : expected) {
            assertEquals(record, reader.next());
        }
        assertFalse(reader.hasNext());
    }
}
 
Example #14
Source File: SVMLightRecordReaderTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Reads the first two records of svmlight/basic.txt and spot-checks feature values by index.
 *
 * <p>The reader uses one-based indexing, 10 features and no appended label.
 */
@Test
public void testNextRecord() throws IOException, InterruptedException {
    SVMLightRecordReader reader = new SVMLightRecordReader();
    Configuration conf = new Configuration();
    conf.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
    conf.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
    conf.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
    reader.initialize(conf, new FileSplit(new ClassPathResource("svmlight/basic.txt").getFile()));

    List<Writable> firstValues = reader.nextRecord().getRecord();
    assertEquals(new DoubleWritable(1.0), firstValues.get(1));
    assertEquals(new DoubleWritable(3.0), firstValues.get(5));
    assertEquals(new DoubleWritable(4.0), firstValues.get(7));

    List<Writable> secondValues = reader.nextRecord().getRecord();
    assertEquals(new DoubleWritable(0.1), secondValues.get(0));
    assertEquals(new DoubleWritable(6.6), secondValues.get(5));
    assertEquals(new DoubleWritable(80.0), secondValues.get(7));
}
 
Example #15
Source File: LibSvmRecordWriterTest.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Passes svmlight/multilabel.txt to executeTest in multilabel mode.
 *
 * <p>The writer covers feature columns 0..9. The reader expects 10 features and 4 labels, with
 * one-based indexing.
 */
@Test
public void testMultilabelRecord() throws Exception {
    Configuration writerConf = new Configuration();
    writerConf.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
    writerConf.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
    writerConf.setBoolean(LibSvmRecordWriter.MULTILABEL, true);

    Configuration readerConf = new Configuration();
    readerConf.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
    readerConf.setBoolean(LibSvmRecordReader.MULTILABEL, true);
    readerConf.setInt(LibSvmRecordReader.NUM_LABELS, 4);
    readerConf.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);

    executeTest(writerConf, readerConf, new ClassPathResource("svmlight/multilabel.txt").getFile());
}
 
Example #16
Source File: TestNativeImageLoader.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Loads testimages/class0/0.jpg with NativeImageLoader in two ways.
 *
 * <p>As an ImageWritable it must be 32x32 with 3 channels. As a matrix from a 77x33
 * single-channel loader it must have shape [1, 1, 77, 33].
 */
@Test
public void testAsWritable() throws Exception {
    File imageFile = new ClassPathResource("/testimages/class0/0.jpg").getFile();

    ImageWritable writable = new NativeImageLoader().asWritable(imageFile);
    assertEquals(32, writable.getFrame().imageHeight);
    assertEquals(32, writable.getFrame().imageWidth);
    assertEquals(3, writable.getFrame().imageChannels);

    // NOTE(review): these random images are never used; kept so the helpers still run.
    BufferedImage unusedBuffered = makeRandomBufferedImage(0, 0, 3);
    Mat unusedMat = makeRandomImage(0, 0, 4);

    int width = 33, height = 77, channels = 1;
    NativeImageLoader resizingLoader = new NativeImageLoader(height, width, channels);

    INDArray matrix = resizingLoader.asMatrix(imageFile);
    assertEquals(4, matrix.rank());
    assertEquals(1, matrix.size(0));
    assertEquals(1, matrix.size(1));
    assertEquals(height, matrix.size(2));
    assertEquals(width, matrix.size(3));
}
 
Example #17
Source File: TestSerialization.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that a CSVRecordReader behaves the same after Java serialization.
 *
 * <p>The reader is configured to skip 3 lines with a tab delimiter, then serialized and
 * deserialized. Both the original and the copy read iris_tab_delim.txt and must produce
 * identical records, 150 - 3 in total.
 */
@Test
public void testCsvRRSerializationResults() throws Exception {
    int skipLines = 3;
    RecordReader original = new CSVRecordReader(skipLines, '\t');

    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    ObjectOutputStream out = new ObjectOutputStream(buffer);
    out.writeObject(original);
    byte[] serialized = buffer.toByteArray();

    ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(serialized));
    RecordReader copy = (RecordReader) in.readObject();

    File irisTab = new ClassPathResource("iris_tab_delim.txt").getFile();
    original.initialize(new FileSplit(irisTab));
    copy.initialize(new FileSplit(irisTab));

    int records = 0;
    for (; original.hasNext(); records++) {
        assertEquals(original.next(), copy.next());
    }

    assertEquals(150 - skipLines, records);
}
 
Example #18
Source File: LibSvmRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Reading svmlight/inconsistentNumLabels.txt with one-based indexing must throw
 * UnsupportedOperationException.
 *
 * <p>The exception may come from initialization or from any read while the file is drained.
 */
@Test(expected = UnsupportedOperationException.class)
public void testInconsistentNumLabelsException() throws Exception {
    LibSvmRecordReader reader = new LibSvmRecordReader();
    Configuration conf = new Configuration();
    conf.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
    reader.initialize(conf, new FileSplit(new ClassPathResource("svmlight/inconsistentNumLabels.txt").getFile()));
    for (; reader.hasNext(); ) {
        reader.next();
    }
}
 
Example #19
Source File: SVMLightRecordWriterTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Passes svmlight/noLabels.txt to executeTest without any label configuration.
 *
 * <p>The writer covers feature columns 0..9. The reader expects 10 features with one-based
 * indexing.
 */
@Test
public void testNoLabel() throws Exception {
    Configuration writerConf = new Configuration();
    writerConf.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
    writerConf.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);

    Configuration readerConf = new Configuration();
    readerConf.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
    readerConf.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);

    executeTest(writerConf, readerConf, new ClassPathResource("svmlight/noLabels.txt").getFile());
}
 
Example #20
Source File: SingleImageRecordTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Smoke test: resolves two test images and wraps the first in a SingleImageRecord.
 *
 * <p>It asserts nothing; it only fails if a lookup or the construction throws.
 */
@Test
public void testImageRecord() throws Exception {
    File class0Image = new ClassPathResource("/testimages/class0/0.jpg").getFile();
    // NOTE(review): the second image is resolved but never used.
    File class1Image = new ClassPathResource("/testimages/class1/A.jpg").getFile();

    SingleImageRecord record = new SingleImageRecord(class0Image.toURI());

    // TODO: consider adding a JSON (Jackson) serialization round-trip check here.
}
 
Example #21
Source File: FileRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Reads three csvsequence files with FileRecordReader and checks record metadata.
 *
 * <p>After a reset, nextRecord() must return the same records as next(), and each record's
 * metadata URI must be its source file. loadFromMetaData must also rebuild the same Records.
 */
@Test
public void testMeta() throws Exception {
    FileRecordReader reader = new FileRecordReader();

    String[] names = {"csvsequence_0.txt", "csvsequence_1.txt", "csvsequence_2.txt"};
    URI[] uris = new URI[names.length];
    for (int i = 0; i < names.length; i++) {
        uris[i] = new ClassPathResource(names[i]).getFile().toURI();
    }

    reader.initialize(new CollectionInputSplit(Arrays.asList(uris)));

    List<List<Writable>> plain = new ArrayList<>();
    while (reader.hasNext()) {
        plain.add(reader.next());
    }
    assertEquals(3, plain.size());

    reader.reset();
    List<List<Writable>> viaRecord = new ArrayList<>();
    List<Record> records = new ArrayList<>();
    List<RecordMetaData> metaData = new ArrayList<>();
    int index = 0;
    while (reader.hasNext()) {
        Record record = reader.nextRecord();
        viaRecord.add(record.getRecord());
        records.add(record);
        metaData.add(record.getMetaData());

        // Records arrive in split order, so record i must come from file i.
        assertEquals(uris[index], record.getMetaData().getURI());
        index++;
    }

    assertEquals(plain, viaRecord);
    assertEquals(records, reader.loadFromMetaData(metaData));
}
 
Example #22
Source File: TensorFlowImportTest.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the reduce_dim_true text-format graph and writes it to disk.
 *
 * <p>The graph is written with asFlatFile into the libnd4j test resources directory, using an
 * executor configuration with IMPLICIT output mode.
 */
@Test
@Ignore
public void testCrash_119_reduce_dim_true() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialization before the import.
    Nd4j.create(1);

    // Close the classpath stream once we are done with it; it was previously leaked.
    try (val graphStream = new ClassPathResource("tf_graphs/reduce_dim_true.pb.txt").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/reduce_dim_true.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build());
    }
}
 
Example #23
Source File: SVMLightRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Calling next() on an exhausted SVMLightRecordReader must throw NoSuchElementException.
 *
 * <p>The reader reads svmlight/basic.txt with 11 features.
 */
@Test(expected = NoSuchElementException.class)
public void testNoSuchElementException() throws Exception {
    SVMLightRecordReader reader = new SVMLightRecordReader();
    Configuration conf = new Configuration();
    conf.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
    reader.initialize(conf, new FileSplit(new ClassPathResource("svmlight/basic.txt").getFile()));

    // Drain every record, then read once more past the end.
    for (; reader.hasNext(); ) {
        reader.next();
    }
    reader.next();
}
 
Example #24
Source File: LibSvmRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Reading a zero label from svmlight/zeroIndexLabel.txt without zero-based indexing must throw
 * IndexOutOfBoundsException.
 *
 * <p>The reader is in multilabel mode with 10 features, 2 labels and the label appended.
 */
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
    LibSvmRecordReader reader = new LibSvmRecordReader();
    Configuration conf = new Configuration();
    conf.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
    conf.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
    conf.setBoolean(LibSvmRecordReader.MULTILABEL, true);
    conf.setInt(LibSvmRecordReader.NUM_LABELS, 2);

    FileSplit split = new FileSplit(new ClassPathResource("svmlight/zeroIndexLabel.txt").getFile());
    reader.initialize(conf, split);
    reader.next();
}
 
Example #25
Source File: CodecReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Decodes fire_lowres.mp4 with a NativeCodecRecordReader and checks its sequence metadata.
 *
 * <p>500 frames starting at frame 160 are requested. After a reset, nextSequence() must return
 * the same frames, and its metadata must point at the source file and reload the same sequence.
 */
@Ignore
@Test
public void testNativeCodecReaderMeta() throws Exception {
    File video = new ClassPathResource("fire_lowres.mp4").getFile();
    SequenceRecordReader nativeReader = new NativeCodecRecordReader();

    Configuration codecConf = new Configuration();
    codecConf.set(CodecRecordReader.RAVEL, "true");
    codecConf.set(CodecRecordReader.START_FRAME, "160");
    codecConf.set(CodecRecordReader.TOTAL_FRAMES, "500");
    codecConf.set(CodecRecordReader.ROWS, "80");
    codecConf.set(CodecRecordReader.COLUMNS, "46");

    nativeReader.initialize(new FileSplit(video));
    nativeReader.setConf(codecConf);
    assertTrue(nativeReader.hasNext());

    List<List<Writable>> frames = nativeReader.sequenceRecord();
    assertEquals(500, frames.size()); // one entry per requested frame

    nativeReader.reset();
    SequenceRecord sequence = nativeReader.nextSequence();
    assertEquals(frames, sequence.getSequenceRecord());

    RecordMetaData metaData = sequence.getMetaData();
    assertTrue(metaData.getURI().toString().endsWith("fire_lowres.mp4"));
    assertEquals(sequence, nativeReader.loadSequenceFromMetaData(metaData));
}
 
Example #26
Source File: CodecReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that CodecRecordReader decodes fire_lowres.mp4 the same way from a DataInputStream as
 * from a FileSplit, using identical configurations.
 */
@Test
public void testViaDataInputStream() throws Exception {

    File file = new ClassPathResource("fire_lowres.mp4").getFile();
    SequenceRecordReader reader = new CodecRecordReader();
    Configuration conf = new Configuration();
    conf.set(CodecRecordReader.RAVEL, "true");
    conf.set(CodecRecordReader.START_FRAME, "160");
    conf.set(CodecRecordReader.TOTAL_FRAMES, "500");
    conf.set(CodecRecordReader.ROWS, "80");
    conf.set(CodecRecordReader.COLUMNS, "46");

    // Copy the configuration before any reader sees it, so the second reader gets the same settings.
    Configuration conf2 = new Configuration(conf);

    reader.initialize(new FileSplit(file));
    reader.setConf(conf);
    assertTrue(reader.hasNext());
    List<List<Writable>> expected = reader.sequenceRecord();

    SequenceRecordReader reader2 = new CodecRecordReader();
    reader2.setConf(conf2);

    // Close the file stream after decoding; it was previously leaked.
    List<List<Writable>> actual;
    try (DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file))) {
        actual = reader2.sequenceRecord(null, dataInputStream);
    }

    assertEquals(expected, actual);
}
 
Example #27
Source File: DataSetIteratorHelper.java    From Java-Deep-Learning-Cookbook with MIT License 5 votes vote down vote up
/**
 * Builds a DataSetIteratorSplitter over Churn_Modelling.csv.
 *
 * <p>The reader from generateReader is wrapped in a classification RecordReaderDataSetIterator
 * (batchSize, labelIndex, numClasses). A NormalizerStandardize is fitted on that iterator and
 * attached as its preprocessor. The result is passed to DataSetIteratorSplitter with 1250 and
 * 0.8.
 */
private static DataSetIteratorSplitter createDataSetSplitter() throws IOException, InterruptedException {
    final RecordReader churnReader = DataSetIteratorHelper.generateReader(new ClassPathResource("Churn_Modelling.csv").getFile());
    final DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(churnReader, batchSize)
            .classification(labelIndex, numClasses)
            .build();

    final DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(iterator);
    iterator.setPreProcessor(standardizer);

    // NOTE(review): 1250 is presumably the total batch count and 0.8 the train ratio; confirm.
    return new DataSetIteratorSplitter(iterator, 1250, 0.8);
}
 
Example #28
Source File: SVMLightRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Reading svmlight/basic.txt with NUM_LABELS set to 6 must throw IndexOutOfBoundsException.
 *
 * <p>The reader uses 10 features; presumably a label index in the file exceeds the label count.
 */
@Test(expected = IndexOutOfBoundsException.class)
public void testLabelIndexExceedsNumLabels() throws Exception {
    SVMLightRecordReader reader = new SVMLightRecordReader();
    Configuration conf = new Configuration();
    conf.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
    conf.setInt(SVMLightRecordReader.NUM_LABELS, 6);

    FileSplit split = new FileSplit(new ClassPathResource("svmlight/basic.txt").getFile());
    reader.initialize(conf, split);
    reader.next();
}
 
Example #29
Source File: CSVRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Reads issue414.csv with '|' as the delimiter and no skipped lines.
 *
 * <p>Every line must have 10 fields, and the sixth field of line i must equal the i-th expected
 * value. Exactly as many lines as there are expected values must be read.
 */
@Test
public void testPipesAsSplit() throws Exception {

    CSVRecordReader reader = new CSVRecordReader(0, '|');
    reader.initialize(new FileSplit(new ClassPathResource("issue414.csv").getFile()));
    int lineidx = 0;
    List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25);
    while (reader.hasNext()) {
        List<Writable> list = reader.next();

        assertEquals(10, list.size());
        assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt());
        lineidx++;
    }
    // Without this check, a file with fewer lines than expected would pass silently.
    assertEquals(sixthColumn.size(), lineidx);
}
 
Example #30
Source File: SVMLightRecordReaderTest.java    From DataVec with Apache License 2.0 5 votes vote down vote up
/**
 * Reading a zero label from svmlight/zeroIndexLabel.txt without zero-based indexing must throw
 * IndexOutOfBoundsException.
 *
 * <p>The reader is in multilabel mode with 10 features and 2 labels.
 */
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
    SVMLightRecordReader reader = new SVMLightRecordReader();
    Configuration conf = new Configuration();
    conf.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
    conf.setBoolean(SVMLightRecordReader.MULTILABEL, true);
    conf.setInt(SVMLightRecordReader.NUM_LABELS, 2);

    FileSplit split = new FileSplit(new ClassPathResource("svmlight/zeroIndexLabel.txt").getFile());
    reader.initialize(conf, split);
    reader.next();
}