package org.datavec.spark.transform; import org.datavec.image.transform.ImageTransformProcess; import org.datavec.spark.transform.model.Base64NDArrayBody; import org.datavec.spark.transform.model.BatchImageRecord; import org.datavec.spark.transform.model.SingleImageRecord; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.serde.base64.Nd4jBase64; import java.io.File; import static org.junit.Assert.assertEquals; /** * Created by kepricon on 17. 5. 24. */ public class ImageSparkTransformTest { @Test public void testSingleImageSparkTransform() throws Exception { int seed = 12345; File f1 = new ClassPathResource("/testimages/class1/A.jpg").getFile(); SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI()); ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) .scaleImageTransform(10).cropImageTransform(5).build(); ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord); INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); System.out.println("Base 64ed array " + fromBase64); assertEquals(1, fromBase64.size(0)); } @Test public void testBatchImageSparkTransform() throws Exception { int seed = 12345; File f0 = new ClassPathResource("/testimages/class1/A.jpg").getFile(); File f1 = new ClassPathResource("/testimages/class1/B.png").getFile(); File f2 = new ClassPathResource("/testimages/class1/C.jpg").getFile(); BatchImageRecord batch = new BatchImageRecord(); batch.add(f0.toURI()); batch.add(f1.toURI()); batch.add(f2.toURI()); ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) .scaleImageTransform(10).cropImageTransform(5).build(); ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); Base64NDArrayBody body = imgSparkTransform.toArray(batch); INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); System.out.println("Base 64ed array " + fromBase64); assertEquals(3, fromBase64.size(0)); } }