package com.mccorby.federatedlearning.features.diabetes.datasource; import com.mccorby.federatedlearning.core.domain.model.FederatedDataSet; import com.mccorby.federatedlearning.core.repository.FederatedDataSource; import com.mccorby.federatedlearning.datasource.FederatedDataSetImpl; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.InputStreamInputSplit; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.SplitTestAndTrain; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import java.io.IOException; import java.io.InputStream; public class DiabetesFileDataSource implements FederatedDataSource { private static final String TAG = DiabetesFileDataSource.class.getSimpleName(); private InputStream dataFile; private int batchSize; private DataSet trainingData; private DataSet testData; public DiabetesFileDataSource(InputStream dataFile, int batchSize) { this.dataFile = dataFile; this.batchSize = batchSize; } private void createDataSource() throws IOException, InterruptedException { //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing int numLinesToSkip = 0; String delimiter = ","; RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter); recordReader.initialize(new InputStreamInputSplit(dataFile)); //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 11; DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true); DataSet allData = iterator.next(); SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); //Use 80% of data for training trainingData = testAndTrain.getTrain(); testData = testAndTrain.getTest(); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data normalizer.transform(trainingData); //Apply normalization to the training data normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set } @Override public FederatedDataSet getTrainingData() { if (trainingData == null) { try { createDataSource(); } catch (IOException | InterruptedException e) { e.printStackTrace(); } } return new FederatedDataSetImpl(trainingData); } @Override public FederatedDataSet getTestData() { if (testData == null) { try { createDataSource(); } catch (IOException | InterruptedException e) { e.printStackTrace(); } } return new FederatedDataSetImpl(testData); } @Override public FederatedDataSet getCrossValidationData() { return null; } }