org.deeplearning4j.models.word2vec.wordstore.VocabCache Java Examples

The following examples show how to use org.deeplearning4j.models.word2vec.wordstore.VocabCache. You can go to the original project or source file by following the link above each example.
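
As a quick orientation, the snippet below sketches the basic VocabCache contract: registering tokens, mapping words to indexes, and querying the vocabulary. It uses the AbstractCache implementation that appears in many of the examples; the words and frequencies are placeholders rather than code from any of the projects below.

import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;

public class VocabCacheBasics {
    public static void main(String[] args) {
        // AbstractCache is the in-memory VocabCache implementation used in most examples below
        VocabCache<VocabWord> cache = new AbstractCache<>();

        // register a token with frequency 1.0, then map its label to index 0
        cache.addToken(new VocabWord(1.0, "hello"));
        cache.addWordToIndex(0, "hello");

        System.out.println(cache.containsWord("hello")); // true
        System.out.println(cache.numWords());            // 1
        System.out.println(cache.wordAtIndex(0));        // hello
    }
}
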
Example #1
Source File: SparkWord2VecTest.java    From deeplearning4j with Apache License 2.0
@Test
@Ignore("AB 2019/05/21 - Failing - Issue #7657")
public void testStringsTokenization1() throws Exception {
    JavaRDD<String> rddSentences = sc.parallelize(sentences);

    SparkWord2Vec word2Vec = new SparkWord2Vec();
    word2Vec.getConfiguration().setTokenizerFactory(DefaultTokenizerFactory.class.getCanonicalName());
    word2Vec.getConfiguration().setElementsLearningAlgorithm("org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram");
    word2Vec.setExporter(new SparkModelExporter<VocabWord>() {
        @Override
        public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
            rdd.foreach(new TestFn());
        }
    });


    word2Vec.fitSentences(rddSentences);

    VocabCache<ShallowSequenceElement> vocabCache = word2Vec.getShallowVocabCache();

    assertNotNull(vocabCache);

    assertEquals(9, vocabCache.numWords());
    assertEquals(2.0, vocabCache.wordFor(SequenceElement.getLongHash("one")).getElementFrequency(), 1e-5);
    assertEquals(1.0, vocabCache.wordFor(SequenceElement.getLongHash("two")).getElementFrequency(), 1e-5);
}
 
Example #2
Source File: SparkSequenceVectors.java    From deeplearning4j with Apache License 2.0
/**
 * This method builds a shallow vocabulary and Huffman tree
 *
 * @param counter counter holding element frequencies, keyed by element storage id
 * @return the populated shallow vocab cache
 */
protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {

    // TODO: need simplified cache here, that will operate on Long instead of string labels
    VocabCache<ShallowSequenceElement> vocabCache = new AbstractCache<>();
    for (Long id : counter.keySet()) {
        ShallowSequenceElement shallowElement = new ShallowSequenceElement(counter.getCount(id), id);
        vocabCache.addToken(shallowElement);
    }

    // building huffman tree
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);

    return vocabCache;
}
 
Example #3
Source File: VocabHolder.java    From deeplearning4j with Apache License 2.0
public INDArray getSyn0Vector(Integer wordIndex, VocabCache<VocabWord> vocabCache) {
    if (!workers.contains(Thread.currentThread().getId()))
        workers.add(Thread.currentThread().getId());

    VocabWord word = vocabCache.elementAtIndex(wordIndex);

    if (!indexSyn0VecMap.containsKey(word)) {
        synchronized (this) {
            if (!indexSyn0VecMap.containsKey(word)) {
                indexSyn0VecMap.put(word, getRandomSyn0Vec(vectorLength.get(), wordIndex));
            }
        }
    }

    return indexSyn0VecMap.get(word);
}
 
Example #4
Source File: FastText.java    From deeplearning4j with Apache License 2.0
@Override
public VocabCache vocab() {
    if (modelVectorsLoaded) {
        vocabCache = word2Vec.vocab();
    }
    else {
        if (!modelLoaded)
            throw new IllegalStateException("Load model before calling vocab()");

        if (vocabCache == null) {
            vocabCache = new AbstractCache();
        }
        List<String> words = fastTextImpl.getWords();
        for (int i = 0; i < words.size(); ++i) {
            vocabCache.addWordToIndex(i, words.get(i));
            VocabWord word = new VocabWord();
            word.setWord(words.get(i));
            vocabCache.addToken(word);
        }
    }
    return vocabCache;
}
 
Example #5
Source File: DM.java    From deeplearning4j with Apache License 2.0
@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
                @NonNull VectorsConfiguration configuration) {
    this.vocabCache = vocabCache;
    this.lookupTable = lookupTable;
    this.configuration = configuration;

    cbow.configure(vocabCache, lookupTable, configuration);

    this.window = configuration.getWindow();
    this.useAdaGrad = configuration.isUseAdaGrad();
    this.negative = configuration.getNegative();
    this.sampling = configuration.getSampling();

    this.syn0 = ((InMemoryLookupTable<T>) lookupTable).getSyn0();
    this.syn1 = ((InMemoryLookupTable<T>) lookupTable).getSyn1();
    this.syn1Neg = ((InMemoryLookupTable<T>) lookupTable).getSyn1Neg();
    this.expTable = ((InMemoryLookupTable<T>) lookupTable).getExpTable();
    this.table = ((InMemoryLookupTable<T>) lookupTable).getTable();
}
 
Example #6
Source File: WordVectorSerializer.java    From deeplearning4j with Apache License 2.0
/**
 * This method saves the specified SequenceVectors model to the target OutputStream
 *
 * @param vectors SequenceVectors model
 * @param factory SequenceElementFactory implementation for your objects
 * @param stream  Target output stream
 * @param <T>     type of the sequence elements
 */
public static <T extends SequenceElement> void writeSequenceVectors(@NonNull SequenceVectors<T> vectors,
                                                                    @NonNull SequenceElementFactory<T> factory, @NonNull OutputStream stream) throws IOException {
    WeightLookupTable<T> lookupTable = vectors.getLookupTable();
    VocabCache<T> vocabCache = vectors.getVocab();

    try (PrintWriter writer = new PrintWriter(new BufferedWriter(new OutputStreamWriter(stream, StandardCharsets.UTF_8)))) {

        // at first line we save VectorsConfiguration
        writer.write(vectors.getConfiguration().toEncodedJson());

        // now we have elements one by one
        for (int x = 0; x < vocabCache.numWords(); x++) {
            T element = vocabCache.elementAtIndex(x);
            String json = factory.serialize(element);
            double[] vector = lookupTable.vector(element.getLabel()).dup().data().asDouble();
            ElementPair pair = new ElementPair(json, vector);
            writer.println(pair.toEncodedJson());
            writer.flush();
        }
    }
}
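
A usage sketch for the method above, assuming `vectors` is an already fitted SequenceVectors<VocabWord> model; the output path is illustrative, and VocabWordFactory is the same factory class that appears in Example #17 below.

public static void saveVectors(SequenceVectors<VocabWord> vectors) throws IOException {
    // illustrative path; any OutputStream works
    try (OutputStream stream = new FileOutputStream("sequence_vectors.txt")) {
        WordVectorSerializer.writeSequenceVectors(vectors, new VocabWordFactory(), stream);
    }
}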
 
Example #7
Source File: TSNEVisualizationExample.java    From Java-Deep-Learning-Cookbook with MIT License
public static void main(String[] args) throws IOException {
    Nd4j.setDataType(DataBuffer.Type.DOUBLE);
    List<String> cacheList = new ArrayList<>();
    File file = new File("words.txt");
    String outputFile = "tsne-standard-coords.csv";
    Pair<InMemoryLookupTable,VocabCache> vectors = WordVectorSerializer.loadTxt(file);
    VocabCache cache = vectors.getSecond();
    INDArray weights = vectors.getFirst().getSyn0();

    for(int i=0;i<cache.numWords();i++){
        cacheList.add(cache.wordAtIndex(i));
    }

    BarnesHutTsne tsne = new BarnesHutTsne.Builder()
                                            .setMaxIter(100)
                                            .theta(0.5)
                                            .normalize(false)
                                            .learningRate(500)
                                            .useAdaGrad(false)
                                            .build();

    tsne.fit(weights);
    tsne.saveAsFile(cacheList,outputFile);

}
 
Example #8
Source File: VocabHolder.java    From deeplearning4j with Apache License 2.0
public Iterable<Map.Entry<VocabWord, INDArray>> getSplit(VocabCache<VocabWord> vocabCache) {
    Set<Map.Entry<VocabWord, INDArray>> set = new HashSet<>();
    int cnt = 0;
    for (Map.Entry<VocabWord, INDArray> entry : indexSyn0VecMap.entrySet()) {
        set.add(entry);
        cnt++;
        if (cnt > 10)
            break;
    }

    System.out.println("Returning set: " + set.size());

    return set;
}
 
Example #9
Source File: CBOW.java    From deeplearning4j with Apache License 2.0
@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
                @NonNull VectorsConfiguration configuration) {
    this.vocabCache = vocabCache;
    this.lookupTable = lookupTable;
    this.configuration = configuration;

    this.window = configuration.getWindow();
    this.useAdaGrad = configuration.isUseAdaGrad();
    this.negative = configuration.getNegative();
    this.sampling = configuration.getSampling();

    if (configuration.getNegative() > 0) {
        if (((InMemoryLookupTable<T>) lookupTable).getSyn1Neg() == null) {
            logger.info("Initializing syn1Neg...");
            ((InMemoryLookupTable<T>) lookupTable).setUseHS(configuration.isUseHierarchicSoftmax());
            ((InMemoryLookupTable<T>) lookupTable).setNegative(configuration.getNegative());
            ((InMemoryLookupTable<T>) lookupTable).resetWeights(false);
        }
    }


    this.syn0 = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn0());
    this.syn1 = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1());
    this.syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
    //this.expTable = new DeviceLocalNDArray(Nd4j.create(((InMemoryLookupTable<T>) lookupTable).getExpTable()));
    this.expTable = new DeviceLocalNDArray(Nd4j.create(((InMemoryLookupTable<T>) lookupTable).getExpTable(),
            new long[]{((InMemoryLookupTable<T>) lookupTable).getExpTable().length}, syn0.get().dataType()));
    this.table = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getTable());
    this.variableWindows = configuration.getVariableWindows();
}
 
Example #10
Source File: NegativeHolder.java    From deeplearning4j with Apache License 2.0
public synchronized void initHolder(@NonNull VocabCache<VocabWord> vocabCache, double[] expTable, int layerSize) {
    if (!wasInit.get()) {
        this.vocab = vocabCache;
        this.syn1Neg = Nd4j.zeros(vocabCache.numWords(), layerSize);
        makeTable(Math.max(expTable.length, 100000), 0.75);
        wasInit.set(true);
    }
}
 
Example #11
Source File: SecondIterationFunction.java    From deeplearning4j with Apache License 2.0
public SecondIterationFunction(Broadcast<Map<String, Object>> word2vecVarMapBroadcast,
                Broadcast<double[]> expTableBroadcast, Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast) {

    Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue();
    this.expTable = expTableBroadcast.getValue();
    this.vectorLength = (int) word2vecVarMap.get("vectorLength");
    this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad");
    this.negative = (double) word2vecVarMap.get("negative");
    this.window = (int) word2vecVarMap.get("window");
    this.alpha = (double) word2vecVarMap.get("alpha");
    this.minAlpha = (double) word2vecVarMap.get("minAlpha");
    this.totalWordCount = (long) word2vecVarMap.get("totalWordCount");
    this.seed = (long) word2vecVarMap.get("seed");
    this.maxExp = (int) word2vecVarMap.get("maxExp");
    this.iterations = (int) word2vecVarMap.get("iterations");
    this.batchSize = (int) word2vecVarMap.get("batchSize");

    // this.indexSyn0VecMap = new HashMap<>();
    // this.pointSyn1VecMap = new HashMap<>();

    this.vocab = vocabCacheBroadcast.getValue();

    if (this.vocab == null)
        throw new RuntimeException("VocabCache is null");
}
 
Example #12
Source File: InMemoryVocabStoreTests.java    From deeplearning4j with Apache License 2.0
@Test
public void testStorePut() {
    VocabCache<VocabWord> cache = new InMemoryLookupCache();
    assertFalse(cache.containsWord("hello"));
    cache.addWordToIndex(0, "hello");
    assertTrue(cache.containsWord("hello"));
    assertEquals(1, cache.numWords());
    assertEquals("hello", cache.wordAtIndex(0));
}
 
Example #13
Source File: InMemoryLookupCache.java    From deeplearning4j with Apache License 2.0
@Override
public void importVocabulary(VocabCache<VocabWord> vocabCache) {
    for (VocabWord word : vocabCache.vocabWords()) {
        if (vocabs.containsKey(word.getLabel())) {
            wordFrequencies.incrementCount(word.getLabel(), (float) word.getElementFrequency());
        } else {
            tokens.put(word.getLabel(), word);
            vocabs.put(word.getLabel(), word);
            wordFrequencies.incrementCount(word.getLabel(), (float) word.getElementFrequency());
        }
        totalWordOccurrences.addAndGet((long) word.getElementFrequency());
    }
}
 
Example #14
Source File: AbstractCache.java    From deeplearning4j with Apache License 2.0
/**
 * This method imports all elements from the VocabCache passed as argument.
 * If an element already exists, its frequency counts are merged instead of adding a duplicate.
 *
 * @param vocabCache the VocabCache to import elements from
 */
public void importVocabulary(@NonNull VocabCache<T> vocabCache) {
    AtomicBoolean added = new AtomicBoolean(false);
    for (T element : vocabCache.vocabWords()) {
        if (this.addToken(element))
            added.set(true);
    }
    //logger.info("Current state: {}; Adding value: {}", this.documentsCounter.get(), vocabCache.totalNumberOfDocs());
    if (added.get())
        this.documentsCounter.addAndGet(vocabCache.totalNumberOfDocs());
}
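
A minimal sketch of the merge behavior described above; the words and counts are placeholders.

AbstractCache<VocabWord> source = new AbstractCache<>();
source.addToken(new VocabWord(2.0, "red"));
source.addToken(new VocabWord(1.0, "flowers"));

AbstractCache<VocabWord> target = new AbstractCache<>();
target.addToken(new VocabWord(3.0, "red"));

// "red" already exists in target, so its counts are merged; "flowers" is added as a new element
target.importVocabulary(source);
System.out.println(target.numWords()); // 2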
 
Example #15
Source File: Huffman.java    From deeplearning4j with Apache License 2.0
/**
 * This method updates the VocabCache and all of its elements with Huffman indexes.
 * Please note: it should be the same VocabCache that was used for Huffman tree initialization
 *
 * @param cache VocabCache to be updated.
 */
public void applyIndexes(VocabCache<? extends SequenceElement> cache) {
    if (!buildTrigger)
        build();

    for (int a = 0; a < words.size(); a++) {
        if (words.get(a).getLabel() != null) {
            cache.addWordToIndex(a, words.get(a).getLabel());
        } else {
            cache.addWordToIndex(a, words.get(a).getStorageId());
        }

        words.get(a).setIndex(a);
    }
}
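
The typical call sequence, mirroring Example #2: build the tree from the cache's elements, then write the Huffman indexes back into the same cache. A minimal sketch, assuming `vocabCache` is already populated:

Huffman huffman = new Huffman(vocabCache.vocabWords());
huffman.build();
huffman.applyIndexes(vocabCache);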
 
Example #16
Source File: WordVectorSerializer.java    From deeplearning4j with Apache License 2.0
/**
 * This method saves paragraph vectors to the given output stream.
 *
 * @deprecated Use {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
 */
@Deprecated
public static void writeWordVectors(ParagraphVectors vectors, OutputStream stream) {
    try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(stream, StandardCharsets.UTF_8))) {
        /*
        This method acts similarly to w2v csv serialization, except for an additional tag for labels
         */

        VocabCache<VocabWord> vocabCache = vectors.getVocab();
        for (VocabWord word : vocabCache.vocabWords()) {
            StringBuilder builder = new StringBuilder();

            builder.append(word.isLabel() ? "L" : "E").append(" ");
            builder.append(word.getLabel().replaceAll(" ", WHITESPACE_REPLACEMENT)).append(" ");

            INDArray vector = vectors.getWordVectorMatrix(word.getLabel());
            for (int j = 0; j < vector.length(); j++) {
                builder.append(vector.getDouble(j));
                if (j < vector.length() - 1) {
                    builder.append(" ");
                }
            }

            writer.write(builder.append("\n").toString());
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
 
Example #17
Source File: WordVectorSerializer.java    From deeplearning4j with Apache License 2.0
/**
 * This method reads a vocab cache from the provided InputStream.
 * Please note: it reads only vocab content, so it's suitable mostly for BagOfWords/TF-IDF vectorizers
 *
 * @param stream input stream containing a serialized vocab cache
 * @return the restored VocabCache
 * @throws IOException if the stream cannot be read
 */
public static VocabCache<VocabWord> readVocabCache(@NonNull InputStream stream) throws IOException {
    val vocabCache = new AbstractCache.Builder<VocabWord>().build();
    val factory = new VocabWordFactory();
    boolean firstLine = true;
    long totalWordOcc = -1L;
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
        String line;
        while ((line = reader.readLine()) != null) {
            // try to treat the first line as a header with 3 numeric fields
            if (firstLine) {
                firstLine = false;
                val split = line.split("\\ ");

                if (split.length != 3)
                    continue;

                try {
                    vocabCache.setTotalDocCount(Long.valueOf(split[1]));
                    totalWordOcc = Long.valueOf(split[2]);
                    continue;
                } catch (NumberFormatException e) {
                    // no-op
                }
            }

            val word = factory.deserialize(line);

            vocabCache.addToken(word);
            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
        }
    }

    if (totalWordOcc >= 0)
        vocabCache.setTotalWordOccurences(totalWordOcc);

    return vocabCache;
}
 
Example #18
Source File: WordVectorSerializer.java    From deeplearning4j with Apache License 2.0
/**
 * This method loads a Word2Vec model from a CSV file
 *
 * @param inputStream  input stream
 * @return Word2Vec model
 */
public static Word2Vec readAsCsv(@NonNull InputStream inputStream) {
    VectorsConfiguration configuration = new VectorsConfiguration();

    // let's try to load this file as a CSV file
    try {
        log.debug("Trying CSV model restoration...");

        Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(inputStream);
        Word2Vec.Builder builder = new Word2Vec
                .Builder()
                .lookupTable(pair.getFirst())
                .useAdaGrad(false)
                .vocabCache(pair.getSecond())
                .layerSize(pair.getFirst().layerSize())
                // we don't use hs here, because model is incomplete
                .useHierarchicSoftmax(false)
                .resetModel(false);

        TokenizerFactory factory = getTokenizerFactory(configuration);
        if (factory != null) {
            builder.tokenizerFactory(factory);
        }

        return builder.build();
    } catch (Exception ex) {
        throw new RuntimeException("Unable to load model in CSV format", ex);
    }
}
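
A usage sketch for the method above; the path is illustrative. Note that the builder disables hierarchical softmax because, as the comment in the source says, a CSV-serialized model is incomplete.

public static void restoreFromCsv() throws IOException {
    try (InputStream stream = new FileInputStream("word2vec.csv")) { // illustrative path
        Word2Vec w2v = WordVectorSerializer.readAsCsv(stream);
        System.out.println("Vocab size: " + w2v.vocab().numWords());
    }
}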
 
Example #19
Source File: FirstIterationFunction.java    From deeplearning4j with Apache License 2.0
public FirstIterationFunction(Broadcast<Map<String, Object>> word2vecVarMapBroadcast,
                                     Broadcast<double[]> expTableBroadcast, Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast) {

    Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue();
    this.expTable = expTableBroadcast.getValue();
    this.vectorLength = (int) word2vecVarMap.get("vectorLength");
    this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad");
    this.negative = (double) word2vecVarMap.get("negative");
    this.window = (int) word2vecVarMap.get("window");
    this.alpha = (double) word2vecVarMap.get("alpha");
    this.minAlpha = (double) word2vecVarMap.get("minAlpha");
    this.totalWordCount = (long) word2vecVarMap.get("totalWordCount");
    this.seed = (long) word2vecVarMap.get("seed");
    this.maxExp = (int) word2vecVarMap.get("maxExp");
    this.iterations = (int) word2vecVarMap.get("iterations");
    this.batchSize = (int) word2vecVarMap.get("batchSize");
    this.indexSyn0VecMap = new HashMap<>();
    this.pointSyn1VecMap = new HashMap<>();
    this.vocab = vocabCacheBroadcast.getValue();

    if (this.vocab == null)
        throw new RuntimeException("VocabCache is null");

    if (negative > 0) {
        negativeHolder = NegativeHolder.getInstance();
        negativeHolder.initHolder(vocab, expTable, this.vectorLength);
    }
}
 
Example #20
Source File: TextPipeline.java    From deeplearning4j with Apache License 2.0
public Broadcast<VocabCache<VocabWord>> getBroadCastVocabCache() throws IllegalStateException {
    if (vocabCache.numWords() > 0) {
        return vocabCacheBroadcast;
    } else {
        throw new IllegalStateException("VocabCache not set at TextPipeline.");
    }
}
 
Example #21
Source File: TextPipeline.java    From deeplearning4j with Apache License 2.0
public VocabCache<VocabWord> getVocabCache() throws IllegalStateException {
    if (vocabCache != null && vocabCache.numWords() > 0) {
        return vocabCache;
    } else {
        throw new IllegalStateException("VocabCache not set at TextPipeline.");
    }
}
 
Example #22
Source File: TextPipelineTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testFilterMinWordAddVocab() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
    pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);
    Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();

    pipeline.filterMinWordAddVocab(wordFreqCounter);
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

    assertNotNull(vocabCache);

    VocabWord redVocab = vocabCache.tokenFor("red");
    VocabWord flowerVocab = vocabCache.tokenFor("flowers");
    VocabWord worldVocab = vocabCache.tokenFor("world");
    VocabWord strangeVocab = vocabCache.tokenFor("strange");


    assertEquals(redVocab.getWord(), "red");
    assertEquals(redVocab.getElementFrequency(), 1, 0);

    assertEquals(flowerVocab.getWord(), "flowers");
    assertEquals(flowerVocab.getElementFrequency(), 1, 0);

    assertEquals(worldVocab.getWord(), "world");
    assertEquals(worldVocab.getElementFrequency(), 1, 0);

    assertEquals(strangeVocab.getWord(), "strange");
    assertEquals(strangeVocab.getElementFrequency(), 2, 0);

    sc.stop();
}
 
Example #23
Source File: TextPipelineTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testBuildVocabCache() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

    assertNotNull(vocabCache);

    log.info("VocabWords: " + vocabCache.words());
    assertEquals(5, vocabCache.numWords());


    VocabWord redVocab = vocabCache.tokenFor("red");
    VocabWord flowerVocab = vocabCache.tokenFor("flowers");
    VocabWord worldVocab = vocabCache.tokenFor("world");
    VocabWord strangeVocab = vocabCache.tokenFor("strange");

    log.info("Red word: " + redVocab);
    log.info("Flower word: " + flowerVocab);
    log.info("World word: " + worldVocab);
    log.info("Strange word: " + strangeVocab);

    assertEquals(redVocab.getWord(), "red");
    assertEquals(redVocab.getElementFrequency(), 1, 0);

    assertEquals(flowerVocab.getWord(), "flowers");
    assertEquals(flowerVocab.getElementFrequency(), 1, 0);

    assertEquals(worldVocab.getWord(), "world");
    assertEquals(worldVocab.getElementFrequency(), 1, 0);

    assertEquals(strangeVocab.getWord(), "strange");
    assertEquals(strangeVocab.getElementFrequency(), 2, 0);

    sc.stop();
}
 
Example #24
Source File: TextPipelineTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testSyn0AfterFirstIteration() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    //  word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();

    // Get total word count and put into word2vec variable map
    Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    double[] expTable = word2vec.getExpTable();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();

    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
                    vocabWordListRDD.zip(sentenceCountCumSumRDD);

    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    FirstIterationFunction firstIterationFunction = new FirstIterationFunction(word2vecVarMapBroadcast,
                    expTableBroadcast, pipeline.getBroadCastVocabCache());
    JavaRDD<Pair<VocabWord, INDArray>> pointSyn0Vec = vocabWordListSentenceCumSumRDD
                    .mapPartitions(firstIterationFunction).map(new MapToPairFunction());
}
 
Example #25
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0
@Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testIndexPersistence() throws Exception {
    File inputFile = Resources.asFile("big/raw_sentences.txt");
    SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100)
                    .stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5).seed(42).windowSize(5)
                    .iterate(iter).tokenizerFactory(t).build();

    vec.fit();

    VocabCache orig = vec.getVocab();

    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeWordVectors(vec, tempFile);

    WordVectors vec2 = WordVectorSerializer.loadTxtVectors(tempFile);

    VocabCache rest = vec2.vocab();

    assertEquals(orig.totalNumberOfDocs(), rest.totalNumberOfDocs());

    for (VocabWord word : vec.getVocab().vocabWords()) {
        INDArray array1 = vec.getWordVectorMatrix(word.getLabel());
        INDArray array2 = vec2.getWordVectorMatrix(word.getLabel());

        assertEquals(array1, array2);
    }
}
 
Example #26
Source File: PartitionTrainingFunction.java    From deeplearning4j with Apache License 2.0
public PartitionTrainingFunction(@NonNull Broadcast<VocabCache<ShallowSequenceElement>> vocabCacheBroadcast,
                @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                @NonNull Broadcast<VoidConfiguration> paramServerConfigurationBroadcast) {
    this.vocabCacheBroadcast = vocabCacheBroadcast;
    this.configurationBroadcast = vectorsConfigurationBroadcast;
    this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast;
}
 
Example #27
Source File: DistributedFunction.java    From deeplearning4j with Apache License 2.0
public DistributedFunction(@NonNull Broadcast<VoidConfiguration> configurationBroadcast,
                @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                @NonNull Broadcast<VocabCache<ShallowSequenceElement>> shallowVocabBroadcast) {
    this.configurationBroadcast = configurationBroadcast;
    this.vectorsConfigurationBroadcast = vectorsConfigurationBroadcast;
    this.shallowVocabBroadcast = shallowVocabBroadcast;
}
 
Example #28
Source File: TrainingFunction.java    From deeplearning4j with Apache License 2.0
public TrainingFunction(@NonNull Broadcast<VocabCache<ShallowSequenceElement>> vocabCacheBroadcast,
                @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                @NonNull Broadcast<VoidConfiguration> paramServerConfigurationBroadcast) {
    this.vocabCacheBroadcast = vocabCacheBroadcast;
    this.configurationBroadcast = vectorsConfigurationBroadcast;
    this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast;
}
 
Example #29
Source File: WordVectorSerializer.java    From deeplearning4j with Apache License 2.0
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull File file) {
    try (InputStream inputStream = fileStream(file)) {
        return loadTxt(inputStream);
    } catch (IOException readTestException) {
        throw new RuntimeException(readTestException);
    }
}