Java Code Examples for org.nd4j.common.primitives.Pair#makePair()

The following examples show how to use org.nd4j.common.primitives.Pair#makePair(). You can vote up the examples you like or vote down the ones you don't, and follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: WordVectorSerializer.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Reads the next word and its vector from {@code stream}.
 *
 * <p>The word is read with {@code ReadHelper.readString}, wrapped as {@code new VocabWord(1.0, word)}
 * and given the next index from {@code idxCounter}. Then {@code vectorLength} floats are read
 * with {@code ReadHelper.readFloat}.
 *
 * @return the word paired with its vector
 * @throws RuntimeException wrapping any exception raised while reading
 */
@Override
public Pair<VocabWord, float[]> next() {
    try {
        VocabWord element = new VocabWord(1.0, ReadHelper.readString(stream));
        element.setIndex(idxCounter.getAndIncrement());

        float[] vector = new float[vectorLength];
        for (int pos = 0; pos < vector.length; pos++) {
            vector[pos] = ReadHelper.readFloat(stream);
        }
        return Pair.makePair(element, vector);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
 
Example 2
Source File: WordVectorSerializer.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Parses the current line into a word and its vector, then advances to the next line.
 *
 * <p>The line is split on single spaces. The first token is decoded with
 * {@code ReadHelper.decodeB64} and becomes the word. The remaining tokens are parsed as the
 * vector's float components. Each word gets the next index from {@code idxCounter}.
 *
 * @return the word paired with its vector
 * @throws java.util.NoSuchElementException if there is no current line (input exhausted)
 * @throws NumberFormatException if a vector component is not a valid float
 */
public Pair<VocabWord, float[]> next() {
    if (nextLine == null) {
        // Report exhaustion with the standard iterator exception rather than an NPE on split().
        throw new java.util.NoSuchElementException("No more lines to read");
    }

    String[] split = nextLine.split(" ");

    VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0]));
    word.setIndex(idxCounter.getAndIncrement());

    float[] vector = new float[split.length - 1];
    for (int i = 1; i < split.length; i++) {
        vector[i - 1] = Float.parseFloat(split[i]);
    }

    try {
        nextLine = reader.readLine();
    } catch (Exception e) {
        // Deliberate best-effort: a read failure is treated as end of input.
        // NOTE(review): this silently truncates on I/O errors; consider surfacing them.
        nextLine = null;
    }

    return Pair.makePair(word, vector);
}
 
Example 3
Source File: ParagraphVectors.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Infers a vector for {@code document} and returns it keyed by the document id.
 *
 * <p>Only tokens contained in {@code vocab} are used. On success {@code countFinished} is
 * incremented, and so is {@code flag} when it is non-null.
 *
 * @return the document id paired with its inferred vector
 * @throws ND4JIllegalStateException if none of the document's tokens are in the vocabulary
 */
@Override
public Pair<String, INDArray> call() throws Exception {
    // Tokenization and vocabulary lookup happen in this callable, which may run in parallel.
    List<VocabWord> knownWords = new ArrayList<>();
    for (String tok : tokenizerFactory.create(document.getContent()).getTokens()) {
        if (!vocab.containsWord(tok)) {
            continue;
        }
        knownWords.add(vocab.wordFor(tok));
    }

    if (knownWords.isEmpty()) {
        throw new ND4JIllegalStateException("Text passed for inference has no matches in model vocabulary.");
    }

    // Inference itself is single-threaded on the Java side.
    Pair<String, INDArray> inferred = Pair.makePair(document.getId(), inferVector(knownWords));

    countFinished.incrementAndGet();
    if (flag != null) {
        flag.incrementAndGet();
    }
    return inferred;
}
 
Example 4
Source File: CpuWorkspaceDeallocator.java — from deeplearning4j (Apache License 2.0), 5 votes
public CpuWorkspaceDeallocator(@NonNull CpuWorkspace workspace) {
    this.pointersPair = workspace.workspace();
    this.pinnedPointers = workspace.pinnedPointers();
    this.externalPointers = workspace.externalPointers();
    this.location = workspace.getWorkspaceConfiguration().getPolicyLocation();

    if (workspace.mappedFileSize() > 0)
        this.mmapInfo = Pair.makePair(workspace.mmap, workspace.mappedFileSize());
}
 
Example 5
Source File: LabelAwareConverter.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Returns the next document from the backing iterator as a (content, label) pair.
 *
 * <p>Only the document's first label is used.
 *
 * @return the document content paired with its first label
 * @throws IllegalStateException if the document has no labels
 */
@Override
public Pair<String, String> nextSentence() {
    LabelledDocument document = backingIterator.nextDocument();

    // Fail with a descriptive message instead of an opaque NPE / IndexOutOfBoundsException.
    if (document.getLabels() == null || document.getLabels().isEmpty()) {
        throw new IllegalStateException("LabelledDocument has no labels; at least one label is required");
    }

    // TODO: it may be worth allowing more than one label, i.e. passing the same document
    // several times sequentially, once per label
    return Pair.makePair(document.getContent(), document.getLabels().get(0));
}
 
Example 6
Source File: AbstractDataSetIteratorTest.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Builds an {@link Iterable} whose iterators each yield {@code totalRows} synthetic rows.
 *
 * <p>Each row's features are {@code 0, 1, ..., numColumns - 1}. Each label entry is drawn with
 * {@code RandomUtils.nextFloat(0, 5)}. Every call to {@code iterator()} starts a fresh pass.
 *
 * @param totalRows  number of rows each iterator yields
 * @param numColumns length of both the feature and the label arrays
 * @return an iterable over freshly generated (features, labels) rows
 */
protected static Iterable<Pair<float[], float[]>> floatIterable(final int totalRows, final int numColumns) {
    return new Iterable<Pair<float[], float[]>>() {
        @Override
        public Iterator<Pair<float[], float[]>> iterator() {
            return new Iterator<Pair<float[], float[]>>() {
                // Number of rows already returned by next().
                private final AtomicInteger cnt = new AtomicInteger(0);

                @Override
                public boolean hasNext() {
                    // Side-effect free: callers may invoke hasNext() any number of times.
                    return cnt.get() < totalRows;
                }

                @Override
                public Pair<float[], float[]> next() {
                    if (!hasNext()) {
                        throw new java.util.NoSuchElementException();
                    }
                    cnt.incrementAndGet();

                    float[] features = new float[numColumns];
                    float[] labels = new float[numColumns];
                    for (int i = 0; i < numColumns; i++) {
                        features[i] = (float) i;
                        labels[i] = RandomUtils.nextFloat(0, 5);
                    }
                    return Pair.makePair(features, labels);
                }

                @Override
                public void remove() {
                    // no-op
                }
            };
        }
    };
}
 
Example 7
Source File: CountFunction.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Counts element and sequence-label occurrences in {@code sequence} and returns the sequence
 * paired with its number of non-null elements.
 *
 * <p>On first use, the elements learning algorithm named in the broadcast vectors configuration
 * is instantiated reflectively. Every call then re-reads its training driver, calls
 * {@code VoidParameterServer.getInstance().init(...)} with it, and adds this sequence's counts
 * to {@code accumulator}.
 *
 * @param sequence the sequence to count
 * @return the input sequence paired with the number of non-null elements it contains
 * @throws RuntimeException if the configured learning algorithm cannot be instantiated
 */
@Override
public Pair<Sequence<T>, Long> call(Sequence<T> sequence) throws Exception {
    // Recalculate the sequence size rather than trusting it (it may overflow int limits);
    // we loop over the elements for their frequencies anyway.
    Counter<Long> localCounter = new Counter<>();
    long seqLen = 0;

    if (ela == null) {
        try {
            // Class.newInstance() is deprecated (it can rethrow undeclared checked exceptions);
            // invoke the no-arg constructor explicitly instead.
            ela = (SparkElementsLearningAlgorithm) Class
                            .forName(vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm())
                            .getDeclaredConstructor()
                            .newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    driver = ela.getTrainingDriver();

    VoidParameterServer.getInstance().init(voidConfigurationBroadcast.getValue(), new RoutedTransport(), driver);

    for (T element : sequence.getElements()) {
        // Null elements are skipped and do not contribute to seqLen.
        if (element == null) {
            continue;
        }

        // FIXME: hashcode is bad idea here. we need Long id
        localCounter.incrementCount(element.getStorageId(), 1.0f);
        seqLen++;
    }

    // FIXME: we're missing label information here due to shallow vocab mechanics
    if (sequence.getSequenceLabels() != null) {
        for (T label : sequence.getSequenceLabels()) {
            localCounter.incrementCount(label.getStorageId(), 1.0f);
        }
    }

    accumulator.add(localCounter);

    return Pair.makePair(sequence, seqLen);
}
 
Example 8
Source File: ExtraCountFunction.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Counts element and sequence-label occurrences in {@code sequence}, snapshots the counter
 * with {@code buildNetworkSnapshot()}, and adds it to {@code accumulator}.
 *
 * @param sequence the sequence to count
 * @return the input sequence paired with the number of non-null elements it contains
 */
@Override
public Pair<Sequence<T>, Long> call(Sequence<T> sequence) throws Exception {
    ExtraCounter<Long> counts = new ExtraCounter<>();
    long nonNullElements = 0;

    // The sequence's own size is not trusted (it may overflow int limits); since every element
    // is visited for its frequency anyway, the length is recounted here.
    for (T item : sequence.getElements()) {
        if (item != null) {
            // FIXME: hashcode is bad idea here. we need Long id
            counts.incrementCount(item.getStorageId(), 1.0f);
            nonNullElements++;
        }
    }

    // FIXME: we're missing label information here due to shallow vocab mechanics
    if (sequence.getSequenceLabels() != null) {
        for (T label : sequence.getSequenceLabels()) {
            counts.incrementCount(label.getStorageId(), 1.0f);
        }
    }

    counts.buildNetworkSnapshot();
    accumulator.add(counts);

    return Pair.makePair(sequence, nonNullElements);
}
 
Example 9
Source File: CudaOpContext.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Returns the root and node states of the random generator attached to this op context.
 *
 * @return a pair of (root state, node state) as reported by {@code nativeOps}
 */
@Override
public Pair<Long, Long> getRngStates() {
    OpaqueRandomGenerator generator = nativeOps.getGraphContextRandomGenerator(context);
    // Root is queried before node, matching the (first, second) order of the pair.
    Long rootState = nativeOps.getRandomGeneratorRootState(generator);
    Long nodeState = nativeOps.getRandomGeneratorNodeState(generator);
    return Pair.makePair(rootState, nodeState);
}
 
Example 10
Source File: CpuOpContext.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Returns the root and node states of the random generator attached to this op context.
 *
 * @return a pair of (root state, node state) as reported by {@code nativeOps}
 */
@Override
public Pair<Long, Long> getRngStates() {
    OpaqueRandomGenerator rng = nativeOps.getGraphContextRandomGenerator(context);
    // Root is queried before node, matching the (first, second) order of the pair.
    Long root = nativeOps.getRandomGeneratorRootState(rng);
    Long node = nativeOps.getRandomGeneratorNodeState(rng);
    return Pair.makePair(root, node);
}
 
Example 11
Source File: ModelParameterServer.java — from deeplearning4j (Apache License 2.0), 2 votes
/**
 * Returns the position training should start from.
 *
 * @return a pair of (iteration number, epoch number), read from {@code iterationNumber}
 *         and {@code epochNumber} in that order
 */
public Pair<Integer, Integer> getStartPosition() {
    Integer iteration = iterationNumber.get();
    Integer epoch = epochNumber.get();
    return Pair.makePair(iteration, epoch);
}