Java Code Examples for org.nd4j.linalg.dataset.DataSet#merge()

The following examples show how to use org.nd4j.linalg.dataset.DataSet#merge(). You can vote up the examples you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
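Before the indexed examples, a minimal self-contained sketch of the basic call may help orient you. DataSet.merge() takes a List<DataSet> and combines the examples into a single DataSet along the minibatch dimension. The class name MergeSketch and the toy data below are illustrative only, not taken from any of the projects listed:

import java.util.Arrays;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class MergeSketch {
    public static void main(String[] args) {
        // Two single-example datasets: 4 features and 3 labels each
        DataSet a = new DataSet(Nd4j.rand(1, 4), Nd4j.create(new float[] {1, 0, 0}, new int[] {1, 3}));
        DataSet b = new DataSet(Nd4j.rand(1, 4), Nd4j.create(new float[] {0, 1, 0}, new int[] {1, 3}));

        // merge() concatenates along the first (example) dimension
        DataSet merged = DataSet.merge(Arrays.asList(a, b));
        System.out.println(merged.numExamples()); // prints 2
    }
}
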
Example 1
Source File: ListDataSetIterator.java    From deeplearning4j with Apache License 2.0    6 votes
@Override
public DataSet next(int num) {
    int end = curr + num;

    List<DataSet> r = new ArrayList<>();
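    // clamp end so we never read past the backing list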
    if (end >= list.size())
        end = list.size();
    for (; curr < end; curr++) {
        r.add(list.get(curr));
    }

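    // merge the collected examples into a single minibatch DataSet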
    DataSet d = DataSet.merge(r);
    if (preProcessor != null) {
        if (!d.isPreProcessed()) {
            preProcessor.preProcess(d);
            d.markAsPreProcessed();
        }
    }
    return d;
}
 
Example 2
Source File: TestDataSetIterator.java    From nd4j with Apache License 2.0    6 votes
@Override
public DataSet next(int num) {
    int end = curr + num;

    List<DataSet> r = new ArrayList<>();
    if (end >= list.size())
        end = list.size();
    for (; curr < end; curr++) {
        r.add(list.get(curr));
    }

    DataSet d = DataSet.merge(r);
    if (preProcessor != null)
        preProcessor.preProcess(d);
    return d;
}
 
Example 3
Source File: TestDataSetIterator.java    From deeplearning4j with Apache License 2.0    6 votes
@Override
public DataSet next(int num) {
    int end = curr + num;

    List<DataSet> r = new ArrayList<>();
    if (end >= list.size())
        end = list.size();
    for (; curr < end; curr++) {
        r.add(list.get(curr));
    }

    DataSet d = DataSet.merge(r);
    if (preProcessor != null)
        preProcessor.preProcess(d);
    return d;
}
 
Example 4
Source File: SpecialTests.java    From deeplearning4j with Apache License 2.0    5 votes
@Test
public void testScalarShuffle2() {
    List<DataSet> listData = new ArrayList<>();
    for (int i = 0; i < 3; i++) {
        INDArray features = Nd4j.ones(14, 25);
        INDArray label = Nd4j.create(14, 50);
        DataSet dataset = new DataSet(features, label);
        listData.add(dataset);
    }
    DataSet data = DataSet.merge(listData);
    data.shuffle();
}
 
Example 5
Source File: BaseSparkTest.java    From deeplearning4j with Apache License 2.0    5 votes
protected JavaRDD<DataSet> getBasicSparkDataSet(int nRows, INDArray input, INDArray labels) {
    List<DataSet> list = new ArrayList<>();
    for (int i = 0; i < nRows; i++) {
        INDArray inRow = input.getRow(i, true).dup();
        INDArray outRow = labels.getRow(i, true).dup();

        DataSet ds = new DataSet(inRow, outRow);
        list.add(ds);
    }
    // 'data' is a field of the enclosing test class
    data = DataSet.merge(list);
    return sc.parallelize(list);
}
 
Example 6
Source File: WSTestDataSetIterator.java    From deeplearning4j with Apache License 2.0    5 votes
@Override
public DataSet next(int i) {
    final LinkedList<DataSet> parts = new LinkedList<>();
    while(parts.size() < i && hasNext()){
        parts.add(nextOne());
    }
    cursor++;
    return DataSet.merge(parts);
}
 
Example 7
Source File: CifarLoader.java    From DataVec with Apache License 2.0    5 votes
public DataSet next(int batchSize, int exampleNum) {
    List<DataSet> temp = new ArrayList<>();
    DataSet result;
    if (cifarProcessedFilesExists() && useSpecialPreProcessCifar) {
        if (exampleNum == 0 || ((exampleNum / fileNum) == numToConvertDS && train)) {
            fileNum++;
            if (train)
                loadDS.load(new File(trainFilesSerialized + fileNum + ".ser"));
            loadDS.load(new File(testFilesSerialized));
            // Shuffle all examples in file before batching happens also for each reset
            if (shuffle && batchSize > 1)
                loadDS.shuffle(seed);
            loadDSIndex = 0;
            //          inputBatched = loadDS.batchBy(batchSize);
        }
        // TODO loading full train dataset when using cuda causes memory error - find way to load into list off gpu
        //            result = inputBatched.get(batchNum);
        for (int i = 0; i < batchSize; i++) {
            if (loadDS.get(loadDSIndex) != null)
                temp.add(loadDS.get(loadDSIndex));
            else
                break;
            loadDSIndex++;
        }
        if (temp.size() > 1)
            result = DataSet.merge(temp);
        else
            result = temp.get(0);
    } else {
        result = convertDataSet(batchSize);
    }
    return result;
}
 
Example 8
Source File: SpecialTests.java    From deeplearning4j with Apache License 2.0    5 votes
@Test(expected = ND4JIllegalStateException.class)
public void testScalarShuffle1() {
    List<DataSet> listData = new ArrayList<>();
    for (int i = 0; i < 3; i++) {
        INDArray features = Nd4j.ones(25, 25);
        INDArray label = Nd4j.create(new float[] {1}, new int[] {1});
        DataSet dataset = new DataSet(features, label);
        listData.add(dataset);
    }
    DataSet data = DataSet.merge(listData);
    data.shuffle();
}
 
Example 9
Source File: LoneTest.java    From deeplearning4j with Apache License 2.0    5 votes
@Test
public void maskWhenMerge() {
    DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
    DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
    List<DataSet> dataSetList = new ArrayList<DataSet>();
    dataSetList.add(dsA);
    dataSetList.add(dsB);
    DataSet fullDataSet = DataSet.merge(dataSetList);
    assertTrue(fullDataSet.getFeaturesMaskArray() != null);

    DataSet fullDataSetCopy = fullDataSet.copy();
    assertTrue(fullDataSetCopy.getFeaturesMaskArray() != null);

}
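Note on the example above: dsA has 5 time steps and dsB has 3, so DataSet.merge() pads the shorter sequence and attaches mask arrays, which is why getFeaturesMaskArray() returns non-null. A sketch of what to expect from the merged dataset (the shapes reflect my reading of the merge semantics and are stated as assumptions, not verified output):

    // features are padded out to the longest sequence in the merge
    INDArray mergedFeatures = fullDataSet.getFeatures();        // shape [2, 3, 5]
    // the mask is per example and time step: 1 = real data, 0 = padding
    INDArray mask = fullDataSet.getFeaturesMaskArray();         // shape [2, 5]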
 
Example 10
Source File: UnderSamplingPreProcessorTest.java    From deeplearning4j with Apache License 2.0    5 votes
public DataSet knownDistVariedDataSet(float[] dist, boolean twoClass) {
    //construct a dataset with known distribution of minority class and varying time steps
    DataSet batchATimeSteps = makeDataSetSameL(minibatchSize, shortSeq, dist, twoClass);
    DataSet batchBTimeSteps = makeDataSetSameL(minibatchSize, longSeq, dist, twoClass);
    List<DataSet> listofbatches = new ArrayList<>();
    listofbatches.add(batchATimeSteps);
    listofbatches.add(batchBTimeSteps);
    return DataSet.merge(listofbatches);
}
 
Example 11
Source File: CifarLoader.java    From deeplearning4j with Apache License 2.0    5 votes
public DataSet next(int batchSize, int exampleNum) {
    List<DataSet> temp = new ArrayList<>();
    DataSet result;
    if (cifarProcessedFilesExists() && useSpecialPreProcessCifar) {
        if (exampleNum == 0 || ((exampleNum / fileNum) == numToConvertDS && train)) {
            fileNum++;
            if (train)
                loadDS.load(new File(trainFilesSerialized + fileNum + ".ser"));
            loadDS.load(new File(testFilesSerialized));
            // Shuffle all examples in file before batching happens also for each reset
            if (shuffle && batchSize > 1)
                loadDS.shuffle(seed);
            loadDSIndex = 0;
            //          inputBatched = loadDS.batchBy(batchSize);
        }
        // TODO loading full train dataset when using cuda causes memory error - find way to load into list off gpu
        //            result = inputBatched.get(batchNum);
        for (int i = 0; i < batchSize; i++) {
            if (loadDS.get(loadDSIndex) != null)
                temp.add(loadDS.get(loadDSIndex));
            else
                break;
            loadDSIndex++;
        }
        if (temp.size() > 1)
            result = DataSet.merge(temp);
        else
            result = temp.get(0);
    } else {
        result = convertDataSet(batchSize);
    }
    return result;
}
 
Example 12
Source File: SpecialTests.java    From nd4j with Apache License 2.0    5 votes
@Test(expected = ND4JIllegalStateException.class)
public void testScalarShuffle1() throws Exception {
    List<DataSet> listData = new ArrayList<>();
    for (int i = 0; i < 3; i++) {
        INDArray features = Nd4j.ones(25, 25);
        INDArray label = Nd4j.create(new float[] {1}, new int[] {1});
        DataSet dataset = new DataSet(features, label);
        listData.add(dataset);
    }
    DataSet data = DataSet.merge(listData);
    data.shuffle();
}
 
Example 13
Source File: LoneTest.java    From nd4j with Apache License 2.0    5 votes
@Test
public void maskWhenMerge() {
    DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
    DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
    List<DataSet> dataSetList = new ArrayList<DataSet>();
    dataSetList.add(dsA);
    dataSetList.add(dsB);
    DataSet fullDataSet = DataSet.merge(dataSetList);
    assertTrue(fullDataSet.getFeaturesMaskArray() != null);

    DataSet fullDataSetCopy = fullDataSet.copy();
    assertTrue(fullDataSetCopy.getFeaturesMaskArray() != null);

}
 
Example 14
Source File: UnderSamplingPreProcessorTest.java    From nd4j with Apache License 2.0    5 votes
public DataSet knownDistVariedDataSet(float[] dist, boolean twoClass) {
    //construct a dataset with known distribution of minority class and varying time steps
    DataSet batchATimeSteps = makeDataSetSameL(minibatchSize, shortSeq, dist, twoClass);
    DataSet batchBTimeSteps = makeDataSetSameL(minibatchSize, longSeq, dist, twoClass);
    List<DataSet> listofbatches = new ArrayList<>();
    listofbatches.add(batchATimeSteps);
    listofbatches.add(batchBTimeSteps);
    return DataSet.merge(listofbatches);
}
 
Example 15
Source File: TestCompareParameterAveragingSparkVsSingleMachine.java    From deeplearning4j with Apache License 2.0    4 votes
private DataSet getOneDataSet(int totalExamples, int seed) {
    return DataSet.merge(getOneDataSetAsIndividalExamples(totalExamples, seed));
}
 
Example 16
Source File: FileDataSetIterator.java    From deeplearning4j with Apache License 2.0    4 votes
@Override
protected DataSet merge(List<DataSet> toMerge) {
    return DataSet.merge(toMerge);
}
 
Example 17
Source File: GradientSharingTrainingTest.java    From deeplearning4j with Apache License 2.0    4 votes
@Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
public void differentNetsTrainingTest() throws Exception {
    int batch = 3;

    File temp = testDir.newFolder();
    DataSet ds = new IrisDataSetIterator(150, 150).next();
    List<DataSet> list = ds.asList();
    Collections.shuffle(list, new Random(12345));
    int pos = 0;
    int dsCount = 0;
    while (pos < list.size()) {
        List<DataSet> l2 = new ArrayList<>();
        for (int i = 0; i < 3 && pos < list.size(); i++) {
            l2.add(list.get(pos++));
        }
        DataSet d = DataSet.merge(l2);
        File f = new File(temp, dsCount++ + ".bin");
        d.save(f);
    }

    INDArray last = null;
    INDArray lastDup = null;
    for (int i = 0; i < 2; i++) {
        System.out.println("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||");
        log.info("Starting: {}", i);

        MultiLayerConfiguration conf;
        if (i == 0) {
            conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .seed(12345)
                    .list()
                    .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
        } else {
            conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .seed(12345)
                    .list()
                    .layer(new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
        }
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();


        //TODO this probably won't work everywhere...
        String controller = Inet4Address.getLocalHost().getHostAddress();
        String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";

        VoidConfiguration voidConfiguration = VoidConfiguration.builder()
                .unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
                .networkMask(networkMask) // Local network mask
                .controllerAddress(controller)
                .build();
        TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new FixedThresholdAlgorithm(1e-4), batch)
                .rngSeed(12345)
                .collectTrainingStats(false)
                .batchSizePerWorker(batch) // Minibatch size for each worker
                .workersPerNode(2) // Workers per node
                .build();


        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, net, tm);

        //System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));

        String fitPath = "file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/");
        INDArray paramsBefore = net.params().dup();
        for( int j=0; j<3; j++ ) {
            sparkNet.fit(fitPath);
        }

        INDArray paramsAfter = net.params();
        assertNotEquals(paramsBefore, paramsAfter);

        //Also check we don't have any issues
        if(i == 0) {
            last = sparkNet.getNetwork().params();
            lastDup = last.dup();
        } else {
            assertEquals(lastDup, last);
        }
    }
}
 
Example 18
Source File: ScoreExamplesWithKeyFunction.java    From deeplearning4j with Apache License 2.0    4 votes
@Override
public Iterator<Tuple2<K, Double>> call(Iterator<Tuple2<K, DataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyIterator();
    }

    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);

    List<Tuple2<K, Double>> ret = new ArrayList<>();

    List<DataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, DataSet> t2 = iterator.next();
            DataSet ds = t2._2();
            int n = ds.numExamples();
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                                + "data set contains more than 1 example (numExamples: " + n + ")");
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;

        DataSet data = DataSet.merge(collect);


        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();

        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }

    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }

    return ret.iterator();
}
 
Example 19
Source File: ScoreExamplesFunction.java    From deeplearning4j with Apache License 2.0    4 votes
@Override
public Iterator<Double> call(Iterator<DataSet> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyIterator();
    }

    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);

    List<Double> ret = new ArrayList<>();

    List<DataSet> collect = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            DataSet ds = iterator.next();
            int n = ds.numExamples();
            collect.add(ds);
            nExamples += n;
        }
        totalCount += nExamples;

        DataSet data = DataSet.merge(collect);


        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();

        for (double doubleScore : doubleScores) {
            ret.add(doubleScore);
        }
    }

    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }

    return ret.iterator();
}