Java Code Examples for org.deeplearning4j.models.word2vec.wordstore.VocabCache#tokenFor()

The following examples show how to use org.deeplearning4j.models.word2vec.wordstore.VocabCache#tokenFor(). You can vote up the examples you find useful or vote down the ones you don't, and follow the links above each example to go to the original project or source file. You may check out the related API usage on the sidebar.
Example 1
Source File: TextPipelineTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Verifies that filterMinWordAddVocab() populates the pipeline's vocab cache
 * with the expected tokens and raw frequencies after tokenization.
 */
@Test
public void testFilterMinWordAddVocab() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
        pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);
        Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();

        pipeline.filterMinWordAddVocab(wordFreqCounter);
        VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

        assertNotNull(vocabCache);

        VocabWord redVocab = vocabCache.tokenFor("red");
        VocabWord flowerVocab = vocabCache.tokenFor("flowers");
        VocabWord worldVocab = vocabCache.tokenFor("world");
        VocabWord strangeVocab = vocabCache.tokenFor("strange");

        // Fail fast with a clear message if any token is missing, rather than NPE below.
        assertNotNull(redVocab);
        assertNotNull(flowerVocab);
        assertNotNull(worldVocab);
        assertNotNull(strangeVocab);

        // JUnit convention: expected value first, actual second.
        assertEquals("red", redVocab.getWord());
        assertEquals(1, redVocab.getElementFrequency(), 0);

        assertEquals("flowers", flowerVocab.getWord());
        assertEquals(1, flowerVocab.getElementFrequency(), 0);

        assertEquals("world", worldVocab.getWord());
        assertEquals(1, worldVocab.getElementFrequency(), 0);

        assertEquals("strange", strangeVocab.getWord());
        assertEquals(2, strangeVocab.getElementFrequency(), 0);
    } finally {
        // Ensure the Spark context is released even if an assertion fails.
        sc.stop();
    }
}
 
Example 2
Source File: TextPipelineTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Verifies that buildVocabCache() builds a vocab cache with the expected
 * word count, tokens, and raw frequencies.
 */
@Test
public void testBuildVocabCache() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

        assertNotNull(vocabCache);

        log.info("VocabWords: " + vocabCache.words());
        assertEquals(5, vocabCache.numWords());

        VocabWord redVocab = vocabCache.tokenFor("red");
        VocabWord flowerVocab = vocabCache.tokenFor("flowers");
        VocabWord worldVocab = vocabCache.tokenFor("world");
        VocabWord strangeVocab = vocabCache.tokenFor("strange");

        log.info("Red word: " + redVocab);
        log.info("Flower word: " + flowerVocab);
        log.info("World word: " + worldVocab);
        log.info("Strange word: " + strangeVocab);

        // Fail fast with a clear message if any token is missing, rather than NPE below.
        assertNotNull(redVocab);
        assertNotNull(flowerVocab);
        assertNotNull(worldVocab);
        assertNotNull(strangeVocab);

        // JUnit convention: expected value first, actual second.
        assertEquals("red", redVocab.getWord());
        assertEquals(1, redVocab.getElementFrequency(), 0);

        assertEquals("flowers", flowerVocab.getWord());
        assertEquals(1, flowerVocab.getElementFrequency(), 0);

        assertEquals("world", worldVocab.getWord());
        assertEquals(1, worldVocab.getElementFrequency(), 0);

        assertEquals("strange", strangeVocab.getWord());
        assertEquals(2, strangeVocab.getElementFrequency(), 0);
    } finally {
        // Ensure the Spark context is released even if an assertion fails.
        sc.stop();
    }
}
 
Example 3
Source File: TextPipelineTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Runs a single pass of FirstIterationFunction over the tokenized corpus
 * zipped with its sentence-count cumulative sums, and checks that the
 * function yields at least one (VocabWord, INDArray) entry.
 */
@Test
public void testFirstIteration() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

    VocabWord token = vocabCache.tokenFor("strange");
    VocabWord word = vocabCache.wordFor("strange");
    // Sanity-check the vocab before feeding it to the iteration function.
    assertNotNull(token);
    assertNotNull(word);
    log.info("Strange token: " + token);
    log.info("Strange word: " + word);

    // Get total word count and put into word2vec variable map
    Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    double[] expTable = word2vec.getExpTable();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();

    // Pair each sentence's word list with its cumulative word-count offset.
    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
                    vocabWordListRDD.zip(sentenceCountCumSumRDD);

    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator();

    FirstIterationFunction firstIterationFunction = new FirstIterationFunction(
                    word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());

    Iterator<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator);
    assertTrue(ret.hasNext());
}
 
Example 4
Source File: SparkSequenceVectorsTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Fits SparkSequenceVectors on the cyclic test sequences and verifies both
 * the raw element-frequency counter and the derived shallow vocab cache.
 */
@Test
public void testFrequenciesCount() throws Exception {
    JavaRDD<Sequence<VocabWord>> sequences = sc.parallelize(sequencesCyclic);

    SparkSequenceVectors<VocabWord> seqVec = new SparkSequenceVectors<>();

    seqVec.getConfiguration().setTokenizerFactory(DefaultTokenizerFactory.class.getCanonicalName());
    seqVec.getConfiguration().setElementsLearningAlgorithm("org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram");
    seqVec.setExporter(new SparkModelExporter<VocabWord>() {
        @Override
        public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
            rdd.foreach(new SparkWord2VecTest.TestFn());
        }
    });

    seqVec.fitSequences(sequences);

    Counter<Long> counter = seqVec.getCounter();

    // element "0" should have frequency of 20
    assertEquals(20, counter.getCount(0L), 1e-5);

    // elements 1 - 9 should have frequencies of 10
    for (int e = 1; e < sequencesCyclic.get(0).getElements().size() - 1; e++) {
        assertEquals(10, counter.getCount(sequencesCyclic.get(0).getElementByIndex(e).getStorageId()), 1e-5);
    }

    VocabCache<ShallowSequenceElement> shallowVocab = seqVec.getShallowVocabCache();

    assertEquals(10, shallowVocab.numWords());

    ShallowSequenceElement zero = shallowVocab.tokenFor(0L);
    ShallowSequenceElement first = shallowVocab.tokenFor(1L);

    // assertNotNull gives a clearer failure than assertNotEquals(null, ...),
    // and "first" must also be checked before it is dereferenced below.
    assertNotNull(zero);
    assertNotNull(first);

    assertEquals(20.0, zero.getElementFrequency(), 1e-5);
    assertEquals(0, zero.getIndex());

    assertEquals(10.0, first.getElementFrequency(), 1e-5);
}