Java Code Examples for org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#resetSupported()

The following examples show how to use org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator#resetSupported(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: AsyncMultiDataSetIterator.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Wraps {@code iterator} so that its MultiDataSets are prefetched into {@code queue} by a
 * background daemon thread, which is started before this constructor returns.
 *
 * @param iterator     the iterator to read from
 * @param queueSize    requested prefetch size; values below 2 are raised to 2
 * @param queue        buffer the prefetch thread fills
 * @param useWorkspace whether workspaces are enabled for this iterator
 * @param callback     callback stored on this iterator
 * @param deviceId     device id stored on this iterator and handed to the prefetch thread
 */
public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
                                 boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
    this.backedIterator = iterator;
    this.buffer = queue;
    this.callback = callback;
    this.deviceId = deviceId;
    this.useWorkspaces = useWorkspace;

    // A prefetch size below 2 is not allowed; clamp it up.
    this.prefetchSize = Math.max(queueSize, 2);
    this.workspaceId = "AMDSI_ITER-" + java.util.UUID.randomUUID().toString();

    // Rewind an already-exhausted iterator before prefetching starts. hasNext() is only
    // consulted when the iterator supports reset at all.
    if (iterator.resetSupported()) {
        if (!iterator.hasNext()) {
            this.backedIterator.reset();
        }
    }

    this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, deviceId);
    this.thread.setDaemon(true);
    this.thread.start();
}
 
Example 2
Source File: MultiDataSetIteratorSplitter.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Creates a splitter that divides the batches of {@code baseIterator} into a train part and a test part.
 *
 * @param baseIterator the underlying iterator; it must support reset, otherwise it cannot be used for a
 *                     runtime split
 * @param totalBatches total number of batches in the underlying iterator; used to determine the number of
 *                     train/test batches. Must be non-negative.
 * @param ratio        fraction of batches used for training, strictly between 0.0 and 1.0 (0.0 &lt; X &lt; 1.0).
 *                     E.g. with 0.7, 70% of the batches are used for training and the remaining 30% for testing.
 * @throws ND4JIllegalStateException if {@code ratio} is not in (0.0, 1.0), {@code totalBatches} is negative,
 *                                   or {@code baseIterator} does not support reset
 */
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double ratio) {
    // Written as a negated conjunction so that NaN is rejected too.
    if (!(ratio > 0.0 && ratio < 1.0))
        throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 < X < 1.0, got " + ratio);

    if (totalBatches < 0)
        throw new ND4JIllegalStateException("totalBatches should be a non-negative value, got " + totalBatches);

    if (!baseIterator.resetSupported())
        throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");

    this.backedIterator = baseIterator;
    this.totalExamples = totalBatches;
    this.ratio = ratio;
    // Train gets floor(totalBatches * ratio) batches; test gets the remainder.
    this.numTrain = (long) (totalExamples * ratio);
    this.numTest = totalExamples - numTrain;
    // Two-way split mode: no arbitrary split configuration.
    this.ratios = null;
    this.numArbitrarySets = 0;
    this.splits = null;

    log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
 
Example 3
Source File: MultiDataSetIteratorSplitter.java — from deeplearning4j (Apache License 2.0), 6 votes
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, int[] splits) {

        int totalBatches = 0;
        for (val v:splits)
            totalBatches += v;

        if (totalBatches < 0)
            throw new ND4JIllegalStateException("totalExamples number should be positive value");

        if (!baseIterator.resetSupported())
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");


        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = 0.0;
        this.numTrain = (long) (totalExamples * ratio);
        this.numTest = totalExamples - numTrain;
        this.ratios = null;
        this.numArbitrarySets = splits.length;
        this.splits = splits;

        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }
 
Example 4
Source File: SparkAMDSI.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Wraps {@code iterator} so that its MultiDataSets are prefetched into {@code queue} by a
 * background daemon {@code SparkPrefetchThread}, which is started before this constructor returns.
 *
 * @param iterator     the iterator to read from; reset first whenever it supports reset
 * @param queueSize    requested prefetch size; values below 2 are raised to 2
 * @param queue        buffer the prefetch thread fills
 * @param useWorkspace whether workspaces are enabled for this iterator
 * @param callback     callback stored on this iterator
 * @param deviceId     device id stored on this iterator (the prefetch thread itself is given the
 *                     device of the calling thread, not this value)
 */
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
                boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
    this();

    this.backedIterator = iterator;
    this.buffer = queue;
    this.callback = callback;
    this.deviceId = deviceId;
    this.useWorkspaces = useWorkspace;

    // A prefetch size below 2 is not allowed; clamp it up.
    this.prefetchSize = Math.max(queueSize, 2);
    this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID().toString();

    // Unlike the non-Spark variant, the iterator is rewound unconditionally when reset is supported.
    if (iterator.resetSupported()) {
        this.backedIterator.reset();
    }

    // The prefetch thread is bound to the device of the thread running this constructor.
    this.thread = new SparkPrefetchThread(buffer, iterator, terminator,
            Nd4j.getAffinityManager().getDeviceForCurrentThread());

    context = TaskContext.get();

    this.thread.setDaemon(true);
    this.thread.start();
}
 
Example 5
Source File: MultiDataSetIteratorSplitter.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Creates a splitter over {@code baseIterator} whose split sizes are given as fractions of {@code totalBatches}.
 *
 * @param baseIterator the underlying iterator; it must support reset, otherwise it cannot be used for a
 *                     runtime split
 * @param totalBatches total number of batches in the underlying iterator; must be non-negative
 * @param ratios       one fraction per split, each strictly between 0.0 and 1.0; split {@code i} gets
 *                     {@code (int) (totalBatches * ratios[i])} batches
 * @throws ND4JIllegalStateException if {@code ratios} is null/empty or has an entry outside (0.0, 1.0),
 *                                   {@code totalBatches} is negative, or {@code baseIterator} does not
 *                                   support reset
 */
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double[] ratios) {
    if (ratios == null || ratios.length == 0)
        throw new ND4JIllegalStateException("ratios array must not be null or empty");

    for (double r : ratios) {
        // Written as a negated conjunction so that NaN is rejected too.
        if (!(r > 0.0 && r < 1.0))
            throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 < X < 1.0, got " + r);
    }

    if (totalBatches < 0)
        throw new ND4JIllegalStateException("totalBatches should be a non-negative value, got " + totalBatches);

    if (!baseIterator.resetSupported())
        throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");

    this.backedIterator = baseIterator;
    this.totalExamples = totalBatches;
    // Arbitrary-split mode: the two-way ratio is unused (0.0), so numTrain is 0 and numTest is the total.
    this.ratio = 0.0;
    this.numTrain = (long) (totalExamples * ratio);
    this.numTest = totalExamples - numTrain;
    // NOTE(review): the ratios field is left null, as in the other constructors; per-split sizes are
    // carried in splits instead.
    this.ratios = null;
    this.numArbitrarySets = ratios.length;

    // Size from the parameter: the field was just set to null, and dereferencing it here always
    // threw a NullPointerException.
    this.splits = new int[ratios.length];
    for (int i = 0; i < this.splits.length; ++i) {
        this.splits[i] = (int) (totalExamples * ratios[i]);
    }

    log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}