org.deeplearning4j.models.word2vec.VocabWord Java Examples

The following examples show how to use org.deeplearning4j.models.word2vec.VocabWord. All examples are drawn from the deeplearning4j project (Apache License 2.0); the source file is noted above each example.
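Before the examples, here is a minimal sketch of the basic VocabWord lifecycle: construct an element with a frequency and a surface form, register it in a vocab cache, and look it up by index. It is assembled only from calls that appear in the examples below (see Examples #5 and #16); the import paths and the wrapper class are assumptions made to keep the snippet self-contained.

import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;

public class VocabWordBasics {
    public static void main(String[] args) {
        // a VocabWord pairs an element frequency with its surface form
        VocabWord word = new VocabWord(3.0, "tester");

        // register the token in an in-memory vocab cache and assign it an index
        AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
        cache.addToken(word);
        cache.addWordToIndex(0, word.getWord());

        System.out.println(cache.wordAtIndex(0)); // prints "tester"
    }
}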
Example #1
Source File: SentenceBatch.java, from deeplearning4j (Apache License 2.0)
/**
 * Trains a single word via skip-gram.
 * @param param the shared Word2Vec parameters
 * @param i the index of the current word in the sentence
 * @param sentence the sentence to train on
 * @param b random offset that narrows the effective context window
 * @param alpha the learning rate
 * @param changed accumulates indices of vectors updated during training
 */
public void skipGram(Word2VecParam param, int i, List<VocabWord> sentence, int b, double alpha,
                List<Triple<Integer, Integer, Integer>> changed) {

    final VocabWord word = sentence.get(i);
    int window = param.getWindow();
    if (word != null && !sentence.isEmpty()) {
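        // b is a random offset that narrows the effective context window on both sides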
        int end = window * 2 + 1 - b;
        for (int a = b; a < end; a++) {
            if (a != window) {
                int c = i - window + a;
                if (c >= 0 && c < sentence.size()) {
                    VocabWord lastWord = sentence.get(c);
                    iterateSample(param, word, lastWord, alpha, changed);
                }
            }
        }
    }
}
 
Example #2
Source File: Word2VecPerformerVoid.java, from deeplearning4j (Apache License 2.0)
/**
 * Trains a single word via skip-gram.
 * @param i the index of the current word in the sentence
 * @param sentence the sentence to train on
 * @param b random offset that narrows the effective context window
 * @param alpha the learning rate
 */
public void skipGram(int i, List<VocabWord> sentence, int b, double alpha) {

    final VocabWord word = sentence.get(i);
    if (word != null && !sentence.isEmpty()) {
        int end = window * 2 + 1 - b;
        for (int a = b; a < end; a++) {
            if (a != window) {
                int c = i - window + a;
                if (c >= 0 && c < sentence.size()) {
                    VocabWord lastWord = sentence.get(c);
                    iterateSample(word, lastWord, alpha);
                }
            }
        }
    }
}
 
Example #3
Source File: RandomWalkerTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testGraphTraverseRandom2() throws Exception {
    RandomWalker<VocabWord> walker = (RandomWalker<VocabWord>) new RandomWalker.Builder<>(graph)
            .setSeed(12345)
            .setWalkLength(20)
            .setWalkDirection(WalkDirection.FORWARD_UNIQUE)
            .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build();

    int cnt = 0;
    while (walker.hasNext()) {
        Sequence<VocabWord> sequence = walker.next();

        assertNotEquals(null, sequence);
        assertTrue(sequence.getElements().size() <= 10);

        for (VocabWord word : sequence.getElements()) {
            assertNotEquals(null, word);
        }

        cnt++;
    }

    assertEquals(10, cnt);
}
 
Example #4
Source File: WordVectorSerializer.java, from deeplearning4j (Apache License 2.0)
@Override
public Pair<VocabWord, float[]> next() {
    try {
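        // each entry in the binary word2vec stream is a token string followed by vectorLength floats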
        String word = ReadHelper.readString(stream);
        VocabWord element = new VocabWord(1.0, word);
        element.setIndex(idxCounter.getAndIncrement());

        float[] vector = new float[vectorLength];
        for (int i = 0; i < vectorLength; i++) {
            vector[i] = ReadHelper.readFloat(stream);
        }

        return Pair.makePair(element, vector);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
 
Example #5
Source File: FastText.java, from deeplearning4j (Apache License 2.0)
@Override
public VocabCache vocab() {
    if (modelVectorsLoaded) {
        vocabCache = word2Vec.vocab();
    }
    else {
        if (!modelLoaded)
            throw new IllegalStateException("Load model before calling vocab()");

        if (vocabCache == null) {
            vocabCache = new AbstractCache();
        }
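        // rebuild the vocab cache from the fastText model's word list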
        List<String> words = fastTextImpl.getWords();
        for (int i = 0; i < words.size(); ++i) {
            vocabCache.addWordToIndex(i, words.get(i));
            VocabWord word = new VocabWord();
            word.setWord(words.get(i));
            vocabCache.addToken(word);
        }
    }
    return vocabCache;
}
 
Example #6
Source File: ParagraphVectorsTest.java, from deeplearning4j (Apache License 2.0)
@Ignore
@Test
public void testGoogleModelForInference() throws Exception {
    WordVectors googleVectors = WordVectorSerializer.readWord2VecModel(new File("/ext/GoogleNews-vectors-negative300.bin.gz"));

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    ParagraphVectors pv =
                    new ParagraphVectors.Builder().tokenizerFactory(t).iterations(10).useHierarchicSoftmax(false)
                                    .trainWordVectors(false).useExistingWordVectors(googleVectors)
                                    .negativeSample(10).sequenceLearningAlgorithm(new DM<VocabWord>()).build();

    INDArray vec1 = pv.inferVector("This text is pretty awesome");
    INDArray vec2 = pv.inferVector("Fantastic process of crazy things happening inside just for history purposes");

    log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2));
}
 
Example #7
Source File: Word2VecPerformerVoid.java, from deeplearning4j (Apache License 2.0)
@Override
public void call(Pair<List<VocabWord>, AtomicLong> pair) throws Exception {
    double numWordsSoFar = wordCount.getValue().doubleValue();

    List<VocabWord> sentence = pair.getFirst();
    double alpha2 = Math.max(minAlpha, alpha * (1 - (1.0 * numWordsSoFar / (double) totalWords)));
    int totalNewWords = 0;
    trainSentence(sentence, alpha2);
    totalNewWords += sentence.size();

    double newWords = totalNewWords + numWordsSoFar;
    double diff = Math.abs(newWords - lastChecked);
    if (diff >= 10000) {
        lastChecked = (int) newWords;
        log.info("Words so far " + newWords + " out of " + totalWords);
    }

    pair.getSecond().getAndAdd((long) totalNewWords);
}
 
Example #8
Source File: ParagraphVectors.java, from deeplearning4j (Apache License 2.0)
/**
 * Predicts the label of a document by computing the
 * cosine similarity between each label vector and the
 * mean representation of the document's words.
 * @param document the document
 * @return the label with the highest similarity
 */
public String predict(List<VocabWord> document) {
    /*
        This code was transferred from the original ParagraphVectors DL4J implementation and has yet to be tested
     */
    if (document.isEmpty())
        throw new IllegalStateException("Document has no words inside");

    /*
    INDArray arr = Nd4j.create(document.size(), this.layerSize);
    for (int i = 0; i < document.size(); i++) {
        arr.putRow(i, getWordVectorMatrix(document.get(i).getWord()));
    }*/

    INDArray docMean = inferVector(document); //arr.mean(0);
    Counter<String> distances = new Counter<>();

    for (String s : labelsSource.getLabels()) {
        INDArray otherVec = getWordVectorMatrix(s);
        double sim = Transforms.cosineSim(docMean, otherVec);
        distances.incrementCount(s, (float) sim);
    }

    return distances.argMax();
}
 
Example #9
Source File: ParagraphVectors.java, from deeplearning4j (Apache License 2.0)
/**
 * Predicts several labels for a document by computing the
 * cosine similarity between each label vector and the mean
 * representation of the document's words.
 * @param document the document
 * @param limit the maximum number of labels to return
 * @return possible labels in descending order of similarity
 */
public Collection<String> predictSeveral(List<VocabWord> document, int limit) {
    /*
        This code was transferred from the original ParagraphVectors DL4J implementation and has yet to be tested
     */
    if (document.isEmpty())
        throw new IllegalStateException("Document has no words inside");

    /*
    INDArray arr = Nd4j.create(document.size(), this.layerSize);
    for (int i = 0; i < document.size(); i++) {
        arr.putRow(i, getWordVectorMatrix(document.get(i).getWord()));
    }
    */

    INDArray docMean = inferVector(document); //arr.mean(0);
    Counter<String> distances = new Counter<>();

    for (String s : labelsSource.getLabels()) {
        INDArray otherVec = getWordVectorMatrix(s);
        double sim = Transforms.cosineSim(docMean, otherVec);
        log.debug("Similarity inside: [" + s + "] -> " + sim);
        distances.incrementCount(s, (float) sim);
    }

    val keys = distances.keySetSorted();
    return keys.subList(0, Math.min(limit, keys.size()));
}
 
Example #10
Source File: ParagraphVectors.java, from deeplearning4j (Apache License 2.0)
/**
 * Returns the top N labels nearest to the specified text.
 *
 * @param rawText raw text to tokenize and match against the model vocabulary
 * @param topN the number of labels to return
 * @return up to topN nearest labels
 */
public Collection<String> nearestLabels(@NonNull String rawText, int topN) {
    List<String> tokens = tokenizerFactory.create(rawText).getTokens();
    List<VocabWord> document = new ArrayList<>();
    for (String token : tokens) {
        if (vocab.containsWord(token)) {
            document.add(vocab.wordFor(token));
        }
    }

    // return an empty collection for an empty document
    if (document.isEmpty()) {
        log.info("Document passed to nearestLabels() has no matches in model vocabulary");
        return new ArrayList<>();
    }

    return nearestLabels(document, topN);
}
 
Example #11
Source File: ParagraphVectors.java, from deeplearning4j (Apache License 2.0)
@Override
public INDArray call() throws Exception {

    // first part of this callable will be actually run in parallel
    List<String> tokens = tokenizerFactory.create(document).getTokens();
    List<VocabWord> documentAsWords = new ArrayList<>();
    for (String token : tokens) {
        if (vocab.containsWord(token)) {
            documentAsWords.add(vocab.wordFor(token));
        }
    }

    if (documentAsWords.isEmpty())
        throw new ND4JIllegalStateException("Text passed for inference has no matches in model vocabulary.");


    // inference will be single-threaded in java, and parallel in native
    INDArray result = inferVector(documentAsWords);

    countFinished.incrementAndGet();

    if (flag != null)
        flag.incrementAndGet();

    return result;
}
 
Example #12
Source File: PerformanceTests.java, from deeplearning4j (Apache License 2.0)
@Ignore
@Test
public void testWord2VecCBOWBig() throws Exception {
    SentenceIterator iter = new BasicLineIterator("/home/raver119/Downloads/corpus/namuwiki_raw.txt");
    //iter = new BasicLineIterator("/home/raver119/Downloads/corpus/ru_sentences.txt");
    //SentenceIterator iter = new BasicLineIterator("/ext/DATASETS/ru/Socials/ru_sentences.txt");

    TokenizerFactory t = new KoreanTokenizerFactory();
    //t = new DefaultTokenizerFactory();
    //t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(1).iterations(5).learningRate(0.025).layerSize(150)
                    .seed(42).sampling(0).negativeSample(0).useHierarchicSoftmax(true).windowSize(5)
                    .modelUtils(new BasicModelUtils<VocabWord>()).useAdaGrad(false).iterate(iter).workers(8)
                    .allowParallelTokenization(true).tokenizerFactory(t)
                    .elementsLearningAlgorithm(new CBOW<VocabWord>()).build();

    long time1 = System.currentTimeMillis();

    vec.fit();

    long time2 = System.currentTimeMillis();

    log.info("Total execution time: {}", (time2 - time1));
}
 
Example #13
Source File: ParallelTransformerIterator.java, from deeplearning4j (Apache License 2.0)
@Override
public boolean hasNext() {
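    // schedule async transformation of the next document while the buffer has capacity,
    // then report whether any sequences remain buffered or in flight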
    if (buffer.size() < capacity && iterator.hasNextDocument()) {
        CallableTransformer transformer = new CallableTransformer(iterator.nextDocument(), sentenceTransformer);
        Future<Sequence<VocabWord>> futureSequence = executorService.submit(transformer);
        try {
            buffer.put(futureSequence);
        } catch (InterruptedException e) {
            log.error("", e);
        }
    }

    return !buffer.isEmpty() || processing.get() > 0;
}
 
Example #14
Source File: WeightedWalkerTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testBasicIterator1() throws Exception {
    GraphWalker<VocabWord> walker = new WeightedWalker.Builder<>(basicGraph)
                    .setWalkDirection(WalkDirection.FORWARD_PREFERRED).setWalkLength(10)
                    .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED).build();

    int cnt = 0;
    while (walker.hasNext()) {
        Sequence<VocabWord> sequence = walker.next();

        assertNotEquals(null, sequence);
        assertEquals(10, sequence.getElements().size());
        cnt++;
    }

    assertEquals(basicGraph.numVertices(), cnt);
}
 
Example #15
Source File: SequenceVectors.java, from deeplearning4j (Apache License 2.0)
private void initIntersectVectors() {
    if (intersectModel != null && intersectModel.vocab().numWords() > 0) {
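        // collect indexes of words shared with the intersect model and lock them,
        // so their vectors can be overwritten in one scatter update below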
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < intersectModel.vocab().numWords(); ++i) {
            String externalWord = intersectModel.vocab().wordAtIndex(i);
            int index = this.vocab.indexOf(externalWord);
            if (index >= 0) {
                this.vocab.wordFor(externalWord).setLocked(lockFactor);
                indexes.add(index);
            }
        }

        if (indexes.size() > 0) {
            int[] intersectIndexes = Ints.toArray(indexes);

            Nd4j.scatterUpdate(org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate.UpdateOp.ASSIGN,
                    ((InMemoryLookupTable<VocabWord>) lookupTable).getSyn0(),
                    Nd4j.createFromArray(intersectIndexes),
                    ((InMemoryLookupTable<VocabWord>) intersectModel.lookupTable()).getSyn0(),
                    1);
        }
    }
}
 
Example #16
Source File: AbstractCacheTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testHuffman() throws Exception {
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

    cache.addToken(new VocabWord(1.0, "word"));
    cache.addToken(new VocabWord(2.0, "test"));
    cache.addToken(new VocabWord(3.0, "tester"));

    assertEquals(3, cache.numWords());

    Huffman huffman = new Huffman(cache.tokens());
    huffman.build();
    huffman.applyIndexes(cache);

    assertEquals("tester", cache.wordAtIndex(0));
    assertEquals("test", cache.wordAtIndex(1));
    assertEquals("word", cache.wordAtIndex(2));

    VocabWord word = cache.tokenFor("tester");
    assertEquals(0, word.getIndex());
}
 
Example #17
Source File: WeightedWalkerTest.java, from deeplearning4j (Apache License 2.0)
@Before
public void setUp() throws Exception {
    if (basicGraph == null) {
        // we don't really care about this graph, since it's just a basic graph for iteration checks
        basicGraph = new Graph<>(10, false, new AbstractVertexFactory<VocabWord>());

        for (int i = 0; i < 10; i++) {
            basicGraph.getVertex(i).setValue(new VocabWord(i, String.valueOf(i)));

            int x = i + 3;
            if (x >= 10)
                x = 0;
            basicGraph.addEdge(i, x, 1, false);
        }

        basicGraph.addEdge(0, 4, 2, false);
        basicGraph.addEdge(0, 4, 4, false);
        basicGraph.addEdge(0, 4, 6, false);
        basicGraph.addEdge(4, 5, 8, false);
        basicGraph.addEdge(1, 3, 6, false);
        basicGraph.addEdge(9, 7, 4, false);
        basicGraph.addEdge(5, 6, 2, false);
    }
}
 
Example #18
Source File: InMemoryLookupCache.java, from deeplearning4j (Apache License 2.0)
/**
 * Promotes a previously added token to a vocabulary word.
 * @param word the word to add to the vocabulary
 */
@Override
@Deprecated
public synchronized void putVocabWord(String word) {
    if (word == null || word.isEmpty())
        throw new IllegalArgumentException("Word can't be empty or null");
    // STOP and UNK are not added as tokens
    if (word.equals("STOP") || word.equals("UNK"))
        return;
    VocabWord token = tokenFor(word);
    if (token == null)
        throw new IllegalStateException("Word " + word + " not found as token in vocab");
    int ind = token.getIndex();
    addWordToIndex(ind, word);
    if (!hasToken(word))
        throw new IllegalStateException("Unable to add token " + word + " when not already a token");
    vocabs.put(word, token);
    wordIndex.add(word, token.getIndex());
}
 
Example #19
Source File: ParagraphVectorsTest.java, from deeplearning4j (Apache License 2.0)
@Test(timeout = 300000)
public void testParallelIterator() throws IOException {
    TokenizerFactory factory = new DefaultTokenizerFactory();
    SentenceIterator iterator = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt"));

    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(true)
            .tokenizerFactory(factory).build();

    BasicTransformerIterator iter = (BasicTransformerIterator)transformer.iterator();
    for (int i = 0; i < 100; ++i) {
        int cnt = 0;
        long counter = 0;
        Sequence<VocabWord> sequence = null;
        while (iter.hasNext()) {
            sequence = iter.next();
            counter += sequence.size();
            cnt++;
        }
        iter.reset();
        assertEquals(757172, counter);
    }
}
 
Example #20
Source File: SecondIterationFunction.java, from deeplearning4j (Apache License 2.0)
public SecondIterationFunction(Broadcast<Map<String, Object>> word2vecVarMapBroadcast,
                Broadcast<double[]> expTableBroadcast, Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast) {

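    // unpack the broadcast word2vec hyperparameters on the Spark worker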
    Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue();
    this.expTable = expTableBroadcast.getValue();
    this.vectorLength = (int) word2vecVarMap.get("vectorLength");
    this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad");
    this.negative = (double) word2vecVarMap.get("negative");
    this.window = (int) word2vecVarMap.get("window");
    this.alpha = (double) word2vecVarMap.get("alpha");
    this.minAlpha = (double) word2vecVarMap.get("minAlpha");
    this.totalWordCount = (long) word2vecVarMap.get("totalWordCount");
    this.seed = (long) word2vecVarMap.get("seed");
    this.maxExp = (int) word2vecVarMap.get("maxExp");
    this.iterations = (int) word2vecVarMap.get("iterations");
    this.batchSize = (int) word2vecVarMap.get("batchSize");

    this.vocab = vocabCacheBroadcast.getValue();

    if (this.vocab == null)
        throw new RuntimeException("VocabCache is null");
}
 
Example #21
Source File: WordVectorSerializer.java, from deeplearning4j (Apache License 2.0)
/**
 * This method saves Word2Vec model to output stream
 *
 * @param word2Vec Word2Vec
 * @param stream OutputStream
 */
public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream)
        throws IOException {

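    // wrap the Word2Vec model as generic SequenceVectors so the common serializer can be reused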
    SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(word2Vec.getConfiguration())
            .layerSize(word2Vec.getLayerSize()).build();
    vectors.setVocab(word2Vec.getVocab());
    vectors.setLookupTable(word2Vec.getLookupTable());
    vectors.setModelUtils(word2Vec.getModelUtils());
    writeSequenceVectors(vectors, stream);
}
 
Example #22
Source File: TextPipeline.java, from deeplearning4j (Apache License 2.0)
public Broadcast<VocabCache<VocabWord>> getBroadCastVocabCache() throws IllegalStateException {
    if (vocabCache.numWords() > 0) {
        return vocabCacheBroadcast;
    } else {
        throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
    }
}
 
Example #23
Source File: ParagraphVectorsTest.java, from deeplearning4j (Apache License 2.0)
@Test
@Ignore //AB 2020/02/06 - https://github.com/eclipse/deeplearning4j/issues/8677
public void testDirectInference() throws Exception {
    boolean isIntegration = isIntegrationTests();
    File resource = Resources.asFile("/big/raw_sentences.txt");
    SentenceIterator sentencesIter = getIterator(isIntegration, resource);

    ClassPathResource resource_mixed = new ClassPathResource("paravec/");
    File local_resource_mixed = testDir.newFolder();
    resource_mixed.copyDirectory(local_resource_mixed);
    SentenceIterator iter = new AggregatingSentenceIterator.Builder()
                    .addSentenceIterator(sentencesIter)
                    .addSentenceIterator(new FileSentenceIterator(local_resource_mixed)).build();

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
                    .learningRate(0.025).layerSize(150).minLearningRate(0.001)
                    .elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
                    .iterate(iter).tokenizerFactory(t).build();

    wordVectors.fit();

    ParagraphVectors pv = new ParagraphVectors.Builder().tokenizerFactory(t).iterations(10)
                    .useHierarchicSoftmax(true).trainWordVectors(true).useExistingWordVectors(wordVectors)
                    .negativeSample(0).sequenceLearningAlgorithm(new DM<VocabWord>()).build();

    INDArray vec1 = pv.inferVector("This text is pretty awesome");
    INDArray vec2 = pv.inferVector("Fantastic process of crazy things happening inside just for history purposes");

    log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2));
}
 
Example #24
Source File: TextPipeline.java, from deeplearning4j (Apache License 2.0)
public JavaRDD<List<VocabWord>> getVocabWordListRDD() throws IllegalStateException {
    if (vocabWordListRDD != null) {
        return vocabWordListRDD;
    } else {
        throw new IllegalStateException("IllegalStateException: vocabWordListRDD not set at TextPipline.");
    }
}
 
Example #25
Source File: WordVectorSerializerTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testIsHeader_withValidHeader() {

    /* Given */
    AbstractCache<VocabWord> cache = new AbstractCache<>();
    String line = "48 100";

    /* When */
    boolean isHeader = WordVectorSerializer.isHeader(line, cache);

    /* Then */
    assertTrue(isHeader);
}
 
Example #26
Source File: SequenceVectorsTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testSequenceLearningAlgo1() throws Exception {
    SequenceVectors<VocabWord> vectors =
                    new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration()).minWordFrequency(5)
                                    .batchSize(250).iterations(1)
                                    .sequenceLearningAlgorithm(
                                                    "org.deeplearning4j.models.embeddings.learning.impl.sequence.DBOW")
                                    .epochs(1).resetModel(false).trainElementsRepresentation(false).build();
}
 
Example #27
Source File: TextPipelineTest.java, from deeplearning4j (Apache License 2.0)
/**
 * This test checks the word sequences and cumulative sentence
 * counts generated when stop words are applied.
 *
 * @throws Exception
 */
@Test @Ignore   //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849
public void testZipFunction1() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    //  word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();

    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
                    vocabWordListRDD.zip(sentenceCountCumSumRDD);
    List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();

    List<VocabWord> vocabWordsList1 = lst.get(0)._1();
    Long cumSumSize1 = lst.get(0)._2();
    assertEquals(3, vocabWordsList1.size());
    assertEquals(vocabWordsList1.get(0).getWord(), "strange");
    assertEquals(vocabWordsList1.get(1).getWord(), "strange");
    assertEquals(vocabWordsList1.get(2).getWord(), "world");
    assertEquals(cumSumSize1, 6L, 0);

    List<VocabWord> vocabWordsList2 = lst.get(1)._1();
    Long cumSumSize2 = lst.get(1)._2();
    assertEquals(2, vocabWordsList2.size());
    assertEquals(vocabWordsList2.get(0).getWord(), "flowers");
    assertEquals(vocabWordsList2.get(1).getWord(), "red");
    assertEquals(cumSumSize2, 9L, 0);

    sc.stop();
}
 
Example #28
Source File: InMemoryLookupCache.java, from deeplearning4j (Apache License 2.0)
/**
 * Adds a word at the given index, creating the backing token if necessary.
 * @param index the index to assign
 * @param word the word to add
 */
@Override
public synchronized void addWordToIndex(int index, String word) {
    if (word == null || word.isEmpty())
        throw new IllegalArgumentException("Word can't be empty or null");



    if (!tokens.containsKey(word)) {
        VocabWord token = new VocabWord(1.0, word);
        tokens.put(word, token);
        wordFrequencies.incrementCount(word, (float) 1.0);
    }

    /*
        If we're speaking about adding any word to index directly, it means it's going to be vocab word, not token
     */
    if (!vocabs.containsKey(word)) {
        VocabWord vw = tokenFor(word);
        vw.setIndex(index);
        vocabs.put(word, vw);
    }

    if (!wordFrequencies.containsElement(word))
        wordFrequencies.incrementCount(word, 1);

    wordIndex.add(word, index);

}
 
Example #29
Source File: ParallelTransformerIteratorTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testCompletes_WhenIteratorHasOneElement() throws Exception {

    StringBuilder testString = new StringBuilder();
    String[] stringsArray = new String[100];
    for (int i = 0; i < 100; ++i) {
        testString.append(i).append(" ");
        stringsArray[i] = Integer.toString(i);
    }
    InputStream inputStream = IOUtils.toInputStream(testString.toString(), "UTF-8");
    SentenceIterator iterator = new BasicLineIterator(inputStream);

    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(true)
            .tokenizerFactory(factory).build();

    Iterator<Sequence<VocabWord>> iter = transformer.iterator();

    Sequence<VocabWord> sequence = null;
    int cnt = 0;
    while (iter.hasNext()) {
        sequence = iter.next();
        List<VocabWord> words = sequence.getElements();
        for (VocabWord word : words) {
            assertEquals(stringsArray[cnt], word.getWord());
            ++cnt;
        }
    }

}
 
Example #30
Source File: RandomWalkerTest.java, from deeplearning4j (Apache License 2.0)
@Test
public void testGraphTraverseRandom6() throws Exception {
    GraphWalker<VocabWord> walker = new RandomWalker.Builder<>(graphDirected).setWalkLength(20)
            .setWalkDirection(WalkDirection.FORWARD_UNIQUE)
            .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build();

    Sequence<VocabWord> sequence = walker.next();
    assertEquals("0", sequence.getElements().get(0).getLabel());
    assertEquals("3", sequence.getElements().get(1).getLabel());
    assertEquals("6", sequence.getElements().get(2).getLabel());
    assertEquals("9", sequence.getElements().get(3).getLabel());

    assertEquals(4, sequence.getElements().size());
}