org.nd4j.linalg.dataset.api.MultiDataSet Java Examples

The following examples show how to use org.nd4j.linalg.dataset.api.MultiDataSet. Each example is preceded by a header naming the source file, the project it comes from, that project's license, and the number of votes the example received.
Example #1
Source File: TrainUtil.java (project: FancyBing; license: GNU General Public License v3.0; 6 votes)
/**
 * Evaluates a two-output {@link ComputationGraph} on {@code testData} and logs the results.
 *
 * <p>Output 0 is scored with a top-N classification {@link Evaluation}; output 1 with a
 * {@link RegressionEvaluation} over one column. Only the time spent inside {@code output(...)} is
 * measured. The iterator is reset after the pass so the caller can reuse it.
 *
 * @param model     model to evaluate; cast to {@link ComputationGraph} for every batch
 * @param outputNum number of class labels, passed to {@code createLabels}
 * @param testData  evaluation batches; label 0 is the classification target, label 1 the regression target
 * @param topN      top-N value for the classification evaluation
 * @param batchSize batch size, used only to report the average per-example inference time
 * @return classification accuracy for output 0
 */
public static double evaluate(Model model, int outputNum, MultiDataSetIterator testData, int topN, int batchSize) {
    log.info("Evaluate model....");
    Evaluation clsEval = new Evaluation(createLabels(outputNum), topN);
    RegressionEvaluation valueRegEval1 = new RegressionEvaluation(1);
    int count = 0;
    long consume = 0; // accumulated inference time in nanoseconds (data loading excluded)

    while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        long begin = System.nanoTime();
        INDArray[] output = ((ComputationGraph) model).output(false, ds.getFeatures());
        consume += System.nanoTime() - begin;
        clsEval.eval(ds.getLabels(0), output[0]);
        valueRegEval1.eval(ds.getLabels(1), output[1]);
        count++;
    }

    String stats = clsEval.stats();
    // Drop the preamble before the first "===" section. indexOf returns -1 when the marker is
    // missing, and substring(-1) would throw, so fall back to the full text in that case.
    int pos = stats.indexOf("===");
    stats = "\n" + (pos >= 0 ? stats.substring(pos) : stats);
    log.info(stats);
    log.info(valueRegEval1.stats());
    testData.reset();

    // long arithmetic avoids int overflow on large evaluation sets.
    long examples = (long) count * batchSize;
    // Average per-example inference time in microseconds; 0 for an empty iterator instead of NaN.
    float averageMicros = examples > 0 ? (float) consume / examples / 1000 : 0f;
    log.info("Evaluate time: " + consume + " count: " + examples + " average: " + averageMicros);
    return clsEval.accuracy();
}
 
Example #2
Source File: BaseSparkEarlyStoppingTrainer.java (project: deeplearning4j; license: Apache License 2.0; 6 votes)
/**
 * Stores the Spark context, early-stopping configuration, network, training data and listener.
 *
 * @param sc         Spark context
 * @param esConfig   early-stopping configuration; must define at least one epoch or iteration
 *                   termination condition
 * @param net        the model being trained
 * @param train      single-DataSet training data
 * @param trainMulti MultiDataSet training data
 * @param listener   early-stopping listener
 * @throws IllegalArgumentException if both the epoch and iteration termination condition lists are null or empty
 */
protected BaseSparkEarlyStoppingTrainer(JavaSparkContext sc, EarlyStoppingConfiguration<T> esConfig, T net,
                JavaRDD<DataSet> train, JavaRDD<MultiDataSet> trainMulti, EarlyStoppingListener<T> listener) {
    // Early stopping needs at least one way to stop. Iteration conditions are only consulted
    // when there are no epoch conditions.
    boolean hasEpochConditions = esConfig.getEpochTerminationConditions() != null
                    && !esConfig.getEpochTerminationConditions().isEmpty();
    if (!hasEpochConditions) {
        boolean hasIterationConditions = esConfig.getIterationTerminationConditions() != null
                        && !esConfig.getIterationTerminationConditions().isEmpty();
        if (!hasIterationConditions) {
            throw new IllegalArgumentException(
                            "Cannot conduct early stopping without a termination condition (both Iteration "
                                            + "and Epoch termination conditions are null/empty)");
        }
    }

    this.sc = sc;
    this.esConfig = esConfig;
    this.net = net;
    this.train = train;
    this.trainMulti = trainMulti;
    this.listener = listener;
}
 
Example #3
Source File: TestBertIterator.java (project: deeplearning4j; license: Apache License 2.0; 6 votes)
/**
 * Checks that featurizing the helper's sentence pairs directly reproduces the features and
 * feature mask of the first minibatch produced by a sentence-pair {@link BertIterator}.
 */
@Test
public void testSentencePairFeaturizer() throws IOException {
    int minibatchSize = 2;
    TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
    BertIterator b = BertIterator.builder()
            .tokenizer(testPairHelper.getTokenizer())
            .minibatchSize(minibatchSize)
            .padMinibatches(true)
            .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
            .vocabMap(testPairHelper.getTokenizer().getVocab())
            .task(BertIterator.Task.SEQ_CLASSIFICATION)
            .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
            .sentencePairProvider(testPairHelper.getPairSentenceProvider())
            .prependToken("[CLS]")
            .appendToken("[SEP]")
            .build();
    MultiDataSet mds = b.next();
    INDArray[] featuresArr = mds.getFeatures();
    INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();

    Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(testPairHelper.getSentencePairs());
    // JUnit's assertEquals takes (expected, actual); keep that order so failure messages are accurate.
    // INDICES_MASK_SEGMENTID yields two feature arrays: token indices and segment ids.
    assertEquals(2, p.getFirst().length);
    assertEquals(featuresArr[0], p.getFirst()[0]);
    assertEquals(featuresArr[1], p.getFirst()[1]);
    assertEquals(featuresMaskArr[0], p.getSecond()[0]);
}
 
Example #4
Source File: SparkAMDSI.java (project: deeplearning4j; license: Apache License 2.0; 6 votes)
/**
 * Creates an asynchronous prefetching wrapper around {@code iterator} and starts its daemon
 * prefetch thread. The Spark {@code TaskContext} of the constructing thread is captured.
 *
 * @param iterator     iterator to prefetch from; reset first whenever it supports reset
 * @param queueSize    requested prefetch depth; values below 2 are raised to 2
 * @param queue        queue that receives the prefetched MultiDataSets
 * @param useWorkspace stored in {@code useWorkspaces}
 * @param callback     stored for later use
 * @param deviceId     stored in {@code deviceId}
 */
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
                boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
    this();

    this.backedIterator = iterator;
    this.buffer = queue;
    this.callback = callback;
    this.deviceId = deviceId;
    this.useWorkspaces = useWorkspace;
    // The prefetch queue must hold at least two elements.
    this.prefetchSize = Math.max(queueSize, 2);
    this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID().toString();

    // NOTE(review): resets unconditionally whenever reset is supported, discarding any progress the
    // caller made on the iterator. AsyncMultiDataSetIterator only resets an exhausted one; confirm intent.
    if (iterator.resetSupported()) {
        this.backedIterator.reset();
    }

    // NOTE(review): the prefetch thread is bound to the constructing thread's device, not to the
    // deviceId argument stored above; confirm this is intended.
    int currentDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread();
    this.thread = new SparkPrefetchThread(buffer, iterator, terminator, currentDevice);

    context = TaskContext.get();

    thread.setDaemon(true);
    thread.start();
}
 
Example #5
Source File: BatchAndExportMultiDataSetsFunction.java (project: deeplearning4j; license: Apache License 2.0; 6 votes)
/**
 * Writes {@code dataSet} to a uniquely named {@code .bin} file under {@code exportBaseDirectory}.
 *
 * @param dataSet      data to save
 * @param partitionIdx partition index, embedded in the file name
 * @param outputCount  per-partition output counter, embedded in the file name
 * @return the URI of the written file, as a string
 * @throws Exception if the URI is malformed or the file cannot be written
 */
private String export(MultiDataSet dataSet, int partitionIdx, int outputCount) throws Exception {
    // Append a "/" only when the base directory does not already end in a path separator.
    boolean endsWithSeparator = exportBaseDirectory.endsWith("/") || exportBaseDirectory.endsWith("\\");
    String separator = endsWithSeparator ? "" : "/";
    String fileName = "mds_" + partitionIdx + jvmuid + "_" + outputCount + ".bin";
    URI target = new URI(exportBaseDirectory + separator + fileName);

    // Fall back to the default Hadoop configuration when none was supplied.
    Configuration hadoopConf;
    if (conf == null) {
        hadoopConf = DefaultHadoopConfig.get();
    } else {
        hadoopConf = conf.getValue().getConfiguration();
    }

    FileSystem fs = FileSystem.get(target, hadoopConf);
    try (FSDataOutputStream stream = fs.create(new Path(target))) {
        dataSet.save(stream);
    }
    return target.toString();
}
 
Example #6
Source File: TestMultiDataSetIterator.java (project: deeplearning4j; license: Apache License 2.0; 6 votes)
/**
 * Merges up to {@code num} of the remaining list elements into one MultiDataSet, advancing the
 * cursor {@code curr}, and applies the preprocessor if one is set.
 */
@Override
public MultiDataSet next(int num) {
    // Never read past the end of the backing list.
    int stop = Math.min(curr + num, list.size());

    List<MultiDataSet> batch = new ArrayList<>();
    while (curr < stop) {
        batch.add(list.get(curr));
        curr++;
    }

    MultiDataSet merged = org.nd4j.linalg.dataset.MultiDataSet.merge(batch);
    if (preProcessor != null) {
        preProcessor.preProcess(merged);
    }
    return merged;
}
 
Example #7
Source File: IEvaluateMDSFlatMapFunction.java (project: deeplearning4j; license: Apache License 2.0; 6 votes)
/**
 * Runs the configured evaluations over every MultiDataSet in {@code dataSetIterator} using
 * {@link EvaluationRunner}.
 *
 * @param dataSetIterator the data to evaluate
 * @return a single-element iterator holding the evaluation results, or an empty iterator when
 *         there is no data or the runner produced {@code null}
 * @throws Exception if waiting on the evaluation future fails
 */
@Override
@SuppressWarnings("unchecked")
public Iterator<T[]> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    // Nothing to evaluate. A second, identical hasNext() check that followed here was dead code.
    if (!dataSetIterator.hasNext()) {
        return Collections.emptyIterator();
    }

    Future<IEvaluation[]> f = EvaluationRunner.getInstance().execute(
            evaluations, evalNumWorkers, evalBatchSize, null, dataSetIterator, true, json, params);

    IEvaluation[] result = f.get();
    if (result == null) {
        return Collections.emptyIterator();
    }
    // NOTE(review): unchecked cast; assumes the runner returns an array compatible with T[].
    return Collections.singletonList((T[]) result).iterator();
}
 
Example #8
Source File: AbstractMultiDataSetNormalizer.java (project: deeplearning4j; license: Apache License 2.0; 6 votes)
/**
 * Fits normalization statistics over every MultiDataSet produced by {@code iterator}.
 *
 * <p>The iterator is reset before the pass. Feature statistics are always rebuilt; label
 * statistics are rebuilt only when {@code isFitLabel()} is true.
 *
 * @param iterator source of the data to accumulate statistics over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    iterator.reset();
    while (iterator.hasNext()) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    this.featureStats = buildList(featureBuilders);
    if (isFitLabel()) {
        this.labelStats = buildList(labelBuilders);
    }
}
 
Example #9
Source File: AsyncMultiDataSetIterator.java (project: deeplearning4j; license: Apache License 2.0; 6 votes)
public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
                                 boolean useWorkspace, DataSetCallback callback, Integer deviceId) {

    if (queueSize < 2)
        queueSize = 2;

    this.callback = callback;
    this.buffer = queue;
    this.backedIterator = iterator;
    this.useWorkspaces = useWorkspace;
    this.prefetchSize = queueSize;
    this.workspaceId = "AMDSI_ITER-" + java.util.UUID.randomUUID().toString();
    this.deviceId = deviceId;

    if (iterator.resetSupported() && !iterator.hasNext())
        this.backedIterator.reset();

    this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, deviceId);

    thread.setDaemon(true);
    thread.start();
}
 
Example #10
Source File: TestComputationGraphNetwork.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Regression test for https://github.com/deeplearning4j/deeplearning4j/issues/6326: a graph with a
 * global dropOut(0.8), a bidirectional LSTM using GaussianNoise dropout, and two RNN output layers
 * must fit a MultiDataSet without failing.
 */
@Test
public void testCompGraphDropoutOutputLayers(){
    LSTM lstm = new LSTM.Builder()
            .nIn(10).nOut(5)
            .activation(Activation.TANH)
            .dropOut(new GaussianNoise(0.05))
            .build();
    RnnOutputLayer firstOutput = new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
            .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10)
            .nOut(6).build();
    RnnOutputLayer secondOutput = new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
            .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10)
            .nOut(4).build();

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .dropOut(0.8)
            .graphBuilder()
            .addInputs("in1", "in2")
            .addVertex("merge", new MergeVertex(), "in1", "in2")
            .addLayer("lstm", new Bidirectional(Bidirectional.Mode.CONCAT, lstm), "merge")
            .addLayer("out1", firstOutput, "lstm")
            .addLayer("out2", secondOutput, "lstm")
            .setOutputs("out1", "out2").build();

    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    // Two 5-wide inputs merge into the LSTM's nIn(10); label widths match out1 (6) and out2 (4).
    INDArray[] features = new INDArray[]{Nd4j.create(1, 5, 5), Nd4j.create(1, 5, 5)};
    INDArray[] labels = new INDArray[]{Nd4j.create(1, 6, 5), Nd4j.create(1, 4, 5)};
    net.fit(new org.nd4j.linalg.dataset.MultiDataSet(features, labels));
}
 
Example #11
Source File: ImageMultiPreProcessingScaler.java (project: nd4j; license: Apache License 2.0; 5 votes)
/**
 * Rescales, in place, every feature array listed in {@code featureIndices}: pixel values are divided
 * by {@code maxPixelVal}, then mapped onto [minRange, maxRange].
 */
@Override
public void preProcess(MultiDataSet multiDataSet) {
    for (int featureIdx : featureIndices) {
        INDArray features = multiDataSet.getFeatures(featureIdx);
        // 1) normalise raw pixel values to 0 -> 1
        features.divi(this.maxPixelVal);
        // 2) stretch to the target width, skipped when that width is exactly 1
        if (this.maxRange - this.minRange != 1) {
            features.muli(this.maxRange - this.minRange);
        }
        // 3) shift by the lower bound, skipped when it is 0
        if (this.minRange != 0) {
            features.addi(this.minRange);
        }
    }
}
 
Example #12
Source File: AbstractMultiDataSetNormalizer.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Fits normalization statistics from this single {@link MultiDataSet} only, replacing any
 * previously fitted statistics. Label statistics are rebuilt only when {@code isFitLabel()} is true.
 *
 * @param dataSet the data to compute statistics from
 */
public void fit(@NonNull MultiDataSet dataSet) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();
    fitPartial(dataSet, featureBuilders, labelBuilders);

    this.featureStats = buildList(featureBuilders);
    if (isFitLabel()) {
        this.labelStats = buildList(labelBuilders);
    }
}
 
Example #13
Source File: EvalTest.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Sanity check: a two-output ComputationGraph can be evaluated with one classification
 * Evaluation per output via {@code evaluate(iterator, Map)}.
 */
@Test
public void testMultiOutputEvalCG(){
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in")
            .layer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0")
            .layer("2", new LSTM.Builder().nIn(10).nOut(10).build(), "0")
            .layer("out1", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1")
            .layer("out2", new RnnOutputLayer.Builder().nIn(10).nOut(20).activation(Activation.SOFTMAX).build(), "2")
            .setOutputs("out1", "out2")
            .build();

    ComputationGraph cg = new ComputationGraph(conf);
    cg.init();

    INDArray[] features = new INDArray[]{Nd4j.create(10, 1, 10)};
    INDArray[] labels = new INDArray[]{Nd4j.create(10, 10, 10), Nd4j.create(10, 20, 10)};
    org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(features, labels);

    // One fresh classification Evaluation for each of the two outputs.
    Map<Integer,org.nd4j.evaluation.IEvaluation[]> m = new HashMap<>();
    for (int outputIdx = 0; outputIdx < 2; outputIdx++) {
        m.put(outputIdx, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()});
    }

    cg.evaluate(new SingletonMultiDataSetIterator(mds), m);
}
 
Example #14
Source File: EarlyTerminationMultiDataSetIterator.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Returns the next MultiDataSet from the underlying iterator, as long as fewer than
 * {@code terminationPoint} minibatches have been returned so far.
 *
 * @return the next MultiDataSet
 * @throws java.util.NoSuchElementException once the termination point has been reached. This
 *         is a RuntimeException, so callers that already catch RuntimeException are unaffected.
 */
@Override
public MultiDataSet next() {
    if (minibatchCount >= terminationPoint) {
        // The Iterator contract signals exhaustion with NoSuchElementException.
        throw new java.util.NoSuchElementException("Calls to next have exceeded the allotted number of minibatches.");
    }
    minibatchCount++;
    return underlyingIterator.next();
}
 
Example #15
Source File: VasttextDataIterator.java (project: scava; license: Eclipse Public License 2.0; 5 votes)
/**
 * Reads up to {@code num} records from every record reader and builds a MultiDataSet from them.
 * When {@code collectMetaData} is set, per-example composite metadata is collected as well.
 *
 * @param num maximum number of records to read from each reader
 * @return the next MultiDataSet
 * @throws NoSuchElementException if {@code hasNext()} is false
 */
@Override
public MultiDataSet next(int num) {
	if (!hasNext()) {
		throw new NoSuchElementException("No next elements");
	}

	// Records read in this call, keyed by record-reader name.
	Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
	List<RecordMetaDataComposableMap> nextMetas = null;
	if (collectMetaData) {
		nextMetas = new ArrayList<RecordMetaDataComposableMap>();
	}

	for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
		String readerName = entry.getKey();
		RecordReader reader = entry.getValue();
		// Cap the initial capacity in case the requested batch size far exceeds the available data.
		List<List<Writable>> batch = new ArrayList<>(Math.min(num, 100000));

		int i = 0;
		while (i < num && reader.hasNext()) {
			List<Writable> record;
			if (collectMetaData) {
				Record r = reader.nextRecord();
				record = r.getRecord();
				// The first reader to reach example i creates that example's composite metadata holder.
				if (nextMetas.size() <= i) {
					nextMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
				}
				nextMetas.get(i).getMeta().put(readerName, r.getMetaData());
			} else {
				record = reader.next();
			}
			batch.add(record);
			i++;
		}

		nextRRVals.put(readerName, batch);
	}

	return nextMultiDataSet(nextRRVals, nextMetas);
}
 
Example #16
Source File: IteratorUtils.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Applies a single-reader {@link RecordReaderMultiDataSetIterator} to a {@code JavaRDD<List<Writable>>}.
 * <b>NOTE</b>: the RecordReaderMultiDataSetIterator <i>must</i> use {@link SparkSourceDummyReader} in
 * place of "real" RecordReader instances.
 *
 * @param rdd      RDD with one record (a list of writables) per element
 * @param iterator RecordReaderMultiDataSetIterator with {@link SparkSourceDummyReader} readers
 * @return the RDD of MultiDataSets produced by {@code mapRRMDSIRecords}
 */
public static JavaRDD<MultiDataSet> mapRRMDSI(JavaRDD<List<Writable>> rdd, RecordReaderMultiDataSetIterator iterator){
    // Validate the iterator's reader setup (arguments 1, 0; presumably one record reader and no
    // sequence readers, which should be confirmed against checkIterator).
    checkIterator(iterator, 1, 0);

    // Wrap each record as a singleton list; the second DataVecRecords argument is left null.
    Function<List<Writable>, DataVecRecords> toRecords = new Function<List<Writable>, DataVecRecords>() {
        @Override
        public DataVecRecords call(List<Writable> record) throws Exception {
            return new DataVecRecords(Collections.singletonList(record), null);
        }
    };
    JavaRDD<DataVecRecords> records = rdd.map(toRecords);
    return mapRRMDSIRecords(records, iterator);
}
 
Example #17
Source File: VasttextDataIterator.java (project: scava; license: Eclipse Public License 2.0; 5 votes)
/**
 * Loads the examples identified by {@code list} into a single {@link MultiDataSet}.
 *
 * <p>Every entry is cast to {@link RecordMetaDataComposableMap}. For each record reader, that
 * reader's share of the metadata (looked up by reader name) is passed to
 * {@code RecordReader.loadFromMetaData}.
 *
 * @param list composite metadata of the examples to load; presumably taken from examples this
 *             iterator produced with metadata collection enabled
 * @return MultiDataSet containing the requested examples
 * @throws IOException if an error occurs during loading of the data
 */
public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
	Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
	// NOTE(review): with collectMetaData set this list is passed on empty; confirm nextMultiDataSet accepts that.
	List<RecordMetaDataComposableMap> nextMetas = collectMetaData
			? new ArrayList<RecordMetaDataComposableMap>()
			: null;

	for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
		String readerName = entry.getKey();

		// Extract this reader's metadata from every composite entry.
		List<RecordMetaData> readerMeta = new ArrayList<>();
		for (RecordMetaData meta : list) {
			readerMeta.add(((RecordMetaDataComposableMap) meta).getMeta().get(readerName));
		}

		List<Record> loaded = entry.getValue().loadFromMetaData(readerMeta);
		List<List<Writable>> writables = new ArrayList<>(list.size());
		for (Record rec : loaded) {
			writables.add(rec.getRecord());
		}
		nextRRVals.put(readerName, writables);
	}

	return nextMultiDataSet(nextRRVals, nextMetas);
}
 
Example #18
Source File: MultiNormalizerHybrid.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Computes input and output normalization statistics from this single {@link MultiDataSet},
 * replacing any previously fitted {@code inputStats} and {@code outputStats}.
 *
 * @param dataSet the data to compute statistics from
 */
@Override
public void fit(@NonNull MultiDataSet dataSet) {
    Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();
    fitPartial(dataSet, inputBuilders, outputBuilders);

    this.inputStats = buildAllStats(inputBuilders);
    this.outputStats = buildAllStats(outputBuilders);
}
 
Example #19
Source File: BaseTrainingMaster.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Exports {@code trainingData} unless it is the same RDD that was exported last time, and returns
 * an RDD of the exported file paths (read from {@code <baseDir>paths/}).
 *
 * <p>Cases handled:
 * <ul>
 *   <li>nothing exported yet: export;</li>
 *   <li>same RDD as the last export (e.g. another epoch): reuse the existing export;</li>
 *   <li>a different RDD was exported before: delete that export, then export this one.</li>
 * </ul>
 */
protected JavaRDD<String> exportIfRequiredMDS(JavaSparkContext sc, JavaRDD<MultiDataSet> trainingData) {
    ExportSupport.assertExportSupported(sc);
    if (collectTrainingStats) {
        stats.logExportStart();
    }

    // Spark documents id() as "A unique ID for this RDD (within its SparkContext)".
    int currentRDDUid = trainingData.id();
    boolean exportedBefore = lastExportedRDDId != Integer.MIN_VALUE;

    String baseDir;
    if (exportedBefore && lastExportedRDDId == currentRDDUid) {
        baseDir = getBaseDirForRDD(trainingData);
    } else {
        if (exportedBefore) {
            // A different RDD was exported previously: clean up its data first.
            deleteTempDir(sc, lastRDDExportPath);
        }
        baseDir = exportMDS(trainingData);
    }

    if (collectTrainingStats) {
        stats.logExportEnd();
    }
    return sc.textFile(baseDir + "paths/");
}
 
Example #20
Source File: DummyBlockMultiDataSetIterator.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Returns up to {@code maxDatasets} MultiDataSets from the wrapped iterator. The array is shorter
 * if the iterator runs out first.
 */
@Override
public MultiDataSet[] next(int maxDatasets) {
    ArrayList<MultiDataSet> collected = new ArrayList<MultiDataSet>(maxDatasets);
    for (int taken = 0; iterator.hasNext() && taken < maxDatasets; taken++) {
        collected.add(iterator.next());
    }
    return collected.toArray(new MultiDataSet[0]);
}
 
Example #21
Source File: AsyncMultiDataSetIteratorTest.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Over 10 epochs, checks that an AsyncMultiDataSetIterator wrapping a variable-length time-series
 * generator returns batches in order. Batch {@code cnt} is identified by the means of its arrays:
 * features = cnt, labels = cnt + 0.25, feature mask = cnt + 0.5, label mask = cnt + 0.75.
 */
@Test
public void testVariableTimeSeries2() throws Exception {
    boolean big = isIntegrationTests();
    int numBatches = big ? 1000 : 100;
    int batchSize = big ? 32 : 8;
    int timeStepsMin = 10;
    int timeStepsMax = big ? 500 : 100;
    int valuesPerTimestep = big ? 128 : 16;

    val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);

    for (int e = 0; e < 10; e++) {
        iterator.reset();
        // Return value intentionally ignored; kept as in the original sequence of calls.
        iterator.hasNext();
        val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);

        for (int cnt = 0; amdsi.hasNext(); cnt++) {
            MultiDataSet mds = amdsi.next();
            String failMsg = "Failed on epoch " + e + "; iteration: " + cnt + ";";
            double expectedBase = (double) cnt;

            assertEquals(failMsg, expectedBase,
                    mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10);
            assertEquals(failMsg, expectedBase + 0.25,
                    mds.getLabels()[0].meanNumber().doubleValue(), 1e-10);
            assertEquals(failMsg, expectedBase + 0.5,
                    mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10);
            assertEquals(failMsg, expectedBase + 0.75,
                    mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10);
        }
    }
}
 
Example #22
Source File: ParameterServerTrainer.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Fits the model on {@code dataSet} and then pushes the updated parameters to the parameter server.
 * Only ComputationGraph models are supported.
 *
 * @throws IllegalArgumentException if the model is not a ComputationGraph
 */
@Override
public void feedMultiDataSet(@NonNull MultiDataSet dataSet, long time) {
    // FIXME: this is wrong, and should be fixed
    // NOTE(review): the FIXME above gives no detail; also, `time` is not used here.
    if (!(getModel() instanceof ComputationGraph)) {
        throw new IllegalArgumentException("MultiLayerNetworks can't fit multi datasets");
    }
    ComputationGraph graph = (ComputationGraph) getModel();
    graph.fit(dataSet);

    log.info("Sending parameters");
    parameterServerClient.pushNDArray(getModel().params());
}
 
Example #23
Source File: MultiNormalizerHybrid.java (project: nd4j; license: Apache License 2.0; 5 votes)
/**
 * Computes input and output normalization statistics from this single {@link MultiDataSet},
 * replacing any previously fitted {@code inputStats} and {@code outputStats}.
 *
 * @param dataSet the data to compute statistics from
 */
@Override
public void fit(@NonNull MultiDataSet dataSet) {
    Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();
    fitPartial(dataSet, inputBuilders, outputBuilders);

    this.inputStats = buildAllStats(inputBuilders);
    this.outputStats = buildAllStats(outputBuilders);
}
 
Example #24
Source File: MultiDataSetIteratorAdapter.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Fetches {@code i} examples from the wrapped iterator, converts them with
 * {@code toMultiDataSet()}, and applies the preprocessor if one is set.
 */
@Override
public MultiDataSet next(int i) {
    MultiDataSet converted = iter.next(i).toMultiDataSet();
    if (preProcessor == null) {
        return converted;
    }
    preProcessor.preProcess(converted);
    return converted;
}
 
Example #25
Source File: OpExecOrderListener.java (project: deeplearning4j; license: Apache License 2.0; 5 votes)
/**
 * Appends an op's name to {@code opNamesList} the first time it is seen; {@code opSet} tracks the
 * names already recorded.
 */
@Override
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
    String opName = op.getName();
    if (opSet.contains(opName)) {
        return; // already recorded
    }
    opNamesList.add(opName);
    opSet.add(opName);
}
 
Example #26
Source File: SparkAMDSI.java (project: deeplearning4j; license: Apache License 2.0; 4 votes)
/**
 * Convenience constructor: delegates to the four-argument constructor, passing {@code true} as the
 * last argument (presumably the useWorkspace flag; confirm against that overload).
 *
 * @param iterator  iterator to prefetch from
 * @param queueSize requested prefetch depth
 * @param queue     queue that receives the prefetched MultiDataSets
 */
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue) {
    this(iterator, queueSize, queue, true);
}
 
Example #27
Source File: ImageMultiPreProcessingScaler.java (project: deeplearning4j; license: Apache License 2.0; 4 votes)
/**
 * Reverts the scaling of {@code toRevert}'s features by passing its feature arrays and feature mask
 * arrays to {@code revertFeatures}. Labels are not passed, so this method does not touch them.
 */
@Override
public void revert(MultiDataSet toRevert) {
    revertFeatures(toRevert.getFeatures(), toRevert.getFeaturesMaskArrays());
}
 
Example #28
Source File: TestBertIterator.java (project: deeplearning4j; license: Apache License 2.0; 4 votes)
/**
 * Checks that featurizing a sentence pair matches horizontally stacking the two sentences
 * featurized separately, for three max-length regimes. The right-hand sentence's segment ids
 * are adjusted to 1 before comparing.
 */
@Test
public void testSentencePairsSingle() throws IOException {
    boolean prependAppend;
    int numOfSentences;

    TestSentenceHelper testHelper = new TestSentenceHelper();
    int shortL = testHelper.getShortestL();
    int longL = testHelper.getLongestL();

    Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
    MultiDataSet fromPair, leftSide, rightSide;

    // JUnit's assertEquals takes (expected, actual). In every check below the expected value is the
    // hstack of the single-sentence results and the actual value is the pair featurization.

    // Case 1: pair max length exactly equals the sum of the lengths -> pop neither, no padding.
    // The pair should equal the hstack of both sides, with segment id 1 for the second sentence.
    prependAppend = true;
    numOfSentences = 1;
    multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences);
    fromPair = multiDataSetTriple.getFirst();
    leftSide = multiDataSetTriple.getSecond();
    rightSide = multiDataSetTriple.getThird();
    assertEquals(Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)), fromPair.getFeatures(0));
    rightSide.getFeatures(1).addi(1); // right-side segment ids are 1
    assertEquals(Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)), fromPair.getFeatures(1));
    assertEquals(Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)), fromPair.getFeaturesMaskArray(0));

    // Case 2: pair max length greater than the sum of the lengths -> pop neither, with padding.
    // Features should equal the hstack of the shorter and the padded longer sentence (with prepend/append).
    // Segment id 1 applies only to the first longL + 1 positions of the right side, not to its padding.
    prependAppend = true;
    numOfSentences = 1;
    multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences);
    fromPair = multiDataSetTriple.getFirst();
    leftSide = multiDataSetTriple.getSecond();
    rightSide = multiDataSetTriple.getThird();
    assertEquals(Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)), fromPair.getFeatures(0));
    rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); // segment id stays 0 for the padded part
    assertEquals(Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)), fromPair.getFeatures(1));
    assertEquals(Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)), fromPair.getFeaturesMaskArray(0));

    // Case 3: pair max length less than the shorter sentence -> pop both.
    // Without prepend/append the pair should equal the hstack, with segment id 1 for the second sentence.
    int maxL = 5; // odd on purpose: the halves are maxL / 2 = 2 and maxL - maxL / 2 = 3
    numOfSentences = 3;
    prependAppend = false;
    multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences);
    fromPair = multiDataSetTriple.getFirst();
    leftSide = multiDataSetTriple.getSecond();
    rightSide = multiDataSetTriple.getThird();
    assertEquals(Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)), fromPair.getFeatures(0));
    rightSide.getFeatures(1).addi(1);
    assertEquals(Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)), fromPair.getFeatures(1));
    assertEquals(Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)), fromPair.getFeaturesMaskArray(0));
}
 
Example #29
Source File: BaseEvaluationListener.java (project: deeplearning4j; license: Apache License 2.0; 4 votes)
/**
 * See {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)}
 *
 * <p>This base implementation does nothing; subclasses may override it to react to activations
 * while evaluations are being computed.
 */
public void activationAvailableEvaluations(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
        INDArray activation){
    //No op
}
 
Example #30
Source File: BaseListener.java (project: deeplearning4j; license: Apache License 2.0; 4 votes)
/**
 * Called when an iteration completes. This base implementation intentionally does nothing;
 * override it to handle the event.
 */
@Override
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
    //No op
}