/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; public class RandomDataSetIteratorTest extends BaseDL4JTest { @Test public void testDSI(){ DataSetIterator iter = new RandomDataSetIterator(5, new long[]{3,4}, new long[]{3,5}, RandomDataSetIterator.Values.RANDOM_UNIFORM, RandomDataSetIterator.Values.ONE_HOT); int count = 0; while(iter.hasNext()){ count++; DataSet ds = iter.next(); assertArrayEquals(new long[]{3,4}, ds.getFeatures().shape()); assertArrayEquals(new long[]{3,5}, ds.getLabels().shape()); assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0 && ds.getFeatures().maxNumber().doubleValue() <= 1.0); assertEquals(Nd4j.ones(3), ds.getLabels().sum(1)); } assertEquals(5, count); } @Test public void testMDSI(){ Nd4j.getRandom().setSeed(12345); MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5) .addFeatures(new long[]{3,4}, RandomMultiDataSetIterator.Values.INTEGER_0_100) .addFeatures(new long[]{3,5}, RandomMultiDataSetIterator.Values.BINARY) .addLabels(new long[]{3,6}, RandomMultiDataSetIterator.Values.ZEROS) .build(); int count = 0; while(iter.hasNext()){ count++; MultiDataSet mds = iter.next(); assertEquals(2, mds.numFeatureArrays()); assertEquals(1, mds.numLabelsArrays()); assertArrayEquals(new long[]{3,4}, mds.getFeatures(0).shape()); assertArrayEquals(new long[]{3,5}, mds.getFeatures(1).shape()); assertArrayEquals(new long[]{3,6}, mds.getLabels(0).shape()); assertTrue(mds.getFeatures(0).minNumber().doubleValue() >= 0 && mds.getFeatures(0).maxNumber().doubleValue() <= 100.0 && mds.getFeatures(0).maxNumber().doubleValue() > 2.0); assertTrue(mds.getFeatures(1).minNumber().doubleValue() == 0.0 && mds.getFeatures(1).maxNumber().doubleValue() == 1.0); assertEquals(0.0, mds.getLabels(0).sumNumber().doubleValue(), 0.0); } assertEquals(5, count); } }