package org.nd4j.linalg.dataset.api.preprocessor; import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { public CompositeDataSetPreProcessorTest(Nd4jBackend backend) { super(backend); } @Override public char ordering() { return 'c'; } @Test(expected = NullPointerException.class) public void when_preConditionsIsNull_expect_NullPointerException() { // Assemble CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(); // Act sut.preProcess(null); } @Test public void when_dataSetIsEmpty_expect_emptyDataSet() { // Assemble CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(); DataSet ds = new DataSet(null, null); // Act sut.preProcess(ds); // Assert assertTrue(ds.isEmpty()); } @Test public void when_notStoppingOnEmptyDataSet_expect_allPreProcessorsCalled() { // Assemble TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(true); TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(true); CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(preProcessor1, preProcessor2); DataSet ds = new DataSet(Nd4j.rand(2, 2), null); // Act sut.preProcess(ds); // Assert assertTrue(preProcessor1.hasBeenCalled); assertTrue(preProcessor2.hasBeenCalled); } @Test public void when_stoppingOnEmptyDataSetAndFirstPreProcessorClearDS_expect_firstPreProcessorsCalled() { // Assemble TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(true); TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(true); CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(true, preProcessor1, preProcessor2); DataSet ds = new DataSet(Nd4j.rand(2, 2), null); // Act sut.preProcess(ds); // Assert assertTrue(preProcessor1.hasBeenCalled); assertFalse(preProcessor2.hasBeenCalled); } @Test public void when_stoppingOnEmptyDataSet_expect_firstPreProcessorsCalled() { // Assemble TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(false); TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(false); CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(true, preProcessor1, preProcessor2); DataSet ds = new DataSet(Nd4j.rand(2, 2), null); // Act sut.preProcess(ds); // Assert assertTrue(preProcessor1.hasBeenCalled); assertTrue(preProcessor2.hasBeenCalled); } public static class TestDataSetPreProcessor implements DataSetPreProcessor { private final boolean clearDataSet; public boolean hasBeenCalled; public TestDataSetPreProcessor(boolean clearDataSet) { this.clearDataSet = clearDataSet; } @Override public void preProcess(org.nd4j.linalg.dataset.api.DataSet dataSet) { hasBeenCalled = true; if(clearDataSet) { dataSet.setFeatures(null); } } } }