Java Code Examples for org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable#setSyn0()

The following examples show how to use org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable#setSyn0(). All of them come from the deeplearning4j project; the source file is noted above each example.
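At its core, setSyn0() attaches a pretrained (or randomly initialized) weight matrix to a lookup table, with one row per vocabulary index. A minimal sketch of the pattern, assuming an already-populated vocabulary cache named vocab and an illustrative vector length of 50:

INDArray syn0 = Nd4j.rand(vocab.numWords(), 50); // one 50-dimensional row per word

InMemoryLookupTable<VocabWord> table = new InMemoryLookupTable
        .Builder<VocabWord>()
        .vectorLength(50)
        .cache(vocab)
        .build();

table.setSyn0(syn0); // row i must correspond to the element at vocab index i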
Example 1
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0
@Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testParaVecSerialization1() throws Exception {
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(14123);
    configuration.setLayersSize(156);

    INDArray syn0 = Nd4j.rand(100, configuration.getLayersSize());
    INDArray syn1 = Nd4j.rand(100, configuration.getLayersSize());

    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

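    // build a synthetic vocabulary of 100 words with random Huffman points and codes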
    for (int i = 0; i < 100; i++) {
        VocabWord word = new VocabWord((float) i, "word_" + i);
        List<Integer> points = new ArrayList<>();
        List<Byte> codes = new ArrayList<>();
        int num = RandomUtils.nextInt(1, 20);
        for (int x = 0; x < num; x++) {
            points.add(RandomUtils.nextInt(1, 100000));
            codes.add(RandomUtils.nextBytes(10)[0]);
        }
        if (RandomUtils.nextInt(0, 10) < 3) {
            word.markAsLabel(true);
        }
        word.setIndex(i);
        word.setPoints(points);
        word.setCodes(codes);
        cache.addToken(word);
        cache.addWordToIndex(i, word.getLabel());
    }

    InMemoryLookupTable<VocabWord> lookupTable =
                    (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                    .vectorLength(configuration.getLayersSize()).cache(cache).build();

    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);

    ParagraphVectors originalVectors =
                    new ParagraphVectors.Builder(configuration).vocabCache(cache).lookupTable(lookupTable).build();

    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);

    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);

    InMemoryLookupTable<VocabWord> restoredLookupTable =
                    (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();

    assertEquals(restoredLookupTable.getSyn0(), lookupTable.getSyn0());
    assertEquals(restoredLookupTable.getSyn1(), lookupTable.getSyn1());

    for (int i = 0; i < cache.numWords(); i++) {
        assertEquals(cache.elementAtIndex(i).isLabel(), restoredVocab.elementAtIndex(i).isLabel());
        assertEquals(cache.wordAtIndex(i), restoredVocab.wordAtIndex(i));
        assertEquals(cache.elementAtIndex(i).getElementFrequency(),
                        restoredVocab.elementAtIndex(i).getElementFrequency(), 0.1f);
        List<Integer> originalPoints = cache.elementAtIndex(i).getPoints();
        List<Integer> restoredPoints = restoredVocab.elementAtIndex(i).getPoints();
        assertEquals(originalPoints.size(), restoredPoints.size());
        for (int x = 0; x < originalPoints.size(); x++) {
            assertEquals(originalPoints.get(x), restoredPoints.get(x));
        }

        List<Byte> originalCodes = cache.elementAtIndex(i).getCodes();
        List<Byte> restoredCodes = restoredVocab.elementAtIndex(i).getCodes();
        assertEquals(originalCodes.size(), restoredCodes.size());
        for (int x = 0; x < originalCodes.size(); x++) {
            assertEquals(originalCodes.get(x), restoredCodes.get(x));
        }
    }
}
 
Example 2
Source File: WordVectorSerializer.java    From deeplearning4j with Apache License 2.0
/**
 * Loads an in-memory cache from the given input stream (sets syn0 and the vocab).
 *
 * @param inputStream  input stream
 * @return a {@link Pair} holding the lookup table and the vocab cache.
 */
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull InputStream inputStream) {
    AbstractCache<VocabWord> cache = new AbstractCache<>();
    LineIterator lines = null;

    try (InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
         BufferedReader reader = new BufferedReader(inputStreamReader)) {
        lines = IOUtils.lineIterator(reader);

        String line = null;
        boolean hasHeader = false;

        /* Check if first line is a header */
        if (lines.hasNext()) {
            line = lines.nextLine();
            hasHeader = isHeader(line, cache);
        }

        if (hasHeader) {
            log.debug("First line is a header");
            line = lines.nextLine();
        }

        List<INDArray> arrays = new ArrayList<>();
        long[] vShape = new long[]{ 1, -1 };

        do {
            String[] tokens = line.split(" ");
            String word = ReadHelper.decodeB64(tokens[0]);
            VocabWord vocabWord = new VocabWord(1.0, word);
            vocabWord.setIndex(cache.numWords());

            cache.addToken(vocabWord);
            cache.addWordToIndex(vocabWord.getIndex(), word);
            cache.putVocabWord(word);

            float[] vector = new float[tokens.length - 1];
            for (int i = 1; i < tokens.length; i++) {
                vector[i - 1] = Float.parseFloat(tokens[i]);
            }

            vShape[1] = vector.length;
            INDArray row = Nd4j.create(vector, vShape);

            arrays.add(row);

            line = lines.hasNext() ? lines.next() : null;
        } while (line != null);

        INDArray syn = Nd4j.vstack(arrays);

        InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
                .Builder<VocabWord>()
                .vectorLength(arrays.get(0).columns())
                .useAdaGrad(false)
                .cache(cache)
                .useHierarchicSoftmax(false)
                .build();

        lookupTable.setSyn0(syn);

        return new Pair<>((InMemoryLookupTable) lookupTable, (VocabCache) cache);
    } catch (IOException readTextStreamException) {
        throw new RuntimeException(readTextStreamException);
    } finally {
        if (lines != null) {
            lines.close();
        }
    }
}
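A short caller sketch for loadTxt() above. The file path is hypothetical, and the snippet assumes the Pair accessors getFirst()/getSecond() from the nd4j primitives used in the return type:

try (InputStream is = new FileInputStream("/path/to/vectors.txt")) {
    Pair<InMemoryLookupTable, VocabCache> pair = WordVectorSerializer.loadTxt(is);
    InMemoryLookupTable table = pair.getFirst();
    VocabCache vocab = pair.getSecond();
    System.out.println("Loaded " + vocab.numWords() + " words, vector length " + table.getSyn0().columns());
}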
 
Example 3
Source File: WordVectorSerializer.java    From deeplearning4j with Apache License 2.0
/**
 * This method can be used to load a previously saved model from an InputStream (such as an HDFS stream).
 * <p>
 * Deprecation note: please consider using readWord2VecModel() or loadStaticModel() instead.
 *
 * @param stream        InputStream that contains a previously serialized model
 * @param skipFirstLine set this to TRUE if the first line contains a CSV header, FALSE otherwise
 * @return the deserialized model as {@link WordVectors}
 * @throws IOException if the stream cannot be read
 * @deprecated Use readWord2VecModel() or loadStaticModel() instead
 */
@Deprecated
public static WordVectors loadTxtVectors(@NonNull InputStream stream, boolean skipFirstLine) throws IOException {
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

    BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
    String line = "";
    List<INDArray> arrays = new ArrayList<>();

    if (skipFirstLine)
        reader.readLine();

    while ((line = reader.readLine()) != null) {
        String[] split = line.split(" ");
        String word = split[0].replaceAll(WHITESPACE_REPLACEMENT, " ");
        VocabWord word1 = new VocabWord(1.0, word);

        word1.setIndex(cache.numWords());

        cache.addToken(word1);

        cache.addWordToIndex(word1.getIndex(), word);

        cache.putVocabWord(word);

        float[] vector = new float[split.length - 1];

        for (int i = 1; i < split.length; i++) {
            vector[i - 1] = Float.parseFloat(split[i]);
        }

        INDArray row = Nd4j.create(vector);

        arrays.add(row);
    }

    reader.close();

    InMemoryLookupTable<VocabWord> lookupTable =
            (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                    .vectorLength(arrays.get(0).columns()).cache(cache).build();

    INDArray syn = Nd4j.vstack(arrays);

    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);

    return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
}
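Because loadTxtVectors() is deprecated, a new caller would follow the Javadoc's advice and use readWord2VecModel() instead; a minimal sketch with a hypothetical file path and word:

Word2Vec w2v = WordVectorSerializer.readWord2VecModel(new File("/path/to/vectors.txt"));
double[] vec = w2v.getWordVector("word_0"); // "word_0" is a hypothetical vocabulary entry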
 
Example 4
Source File: WordVectorSerializer.java    From deeplearning4j with Apache License 2.0
/**
 * This method loads a previously saved SequenceVectors model from an InputStream.
 *
 * @param factory factory used to deserialize the vocabulary elements
 * @param stream  InputStream that contains a previously serialized model
 * @param <T>     type of the sequence elements
 * @return the restored {@link SequenceVectors} model
 */
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(
        @NonNull SequenceElementFactory<T> factory, @NonNull InputStream stream) throws IOException {
    BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"));

    // at first we load vectors configuration
    String line = reader.readLine();
    VectorsConfiguration configuration =
            VectorsConfiguration.fromJson(new String(Base64.decodeBase64(line), "UTF-8"));

    AbstractCache<T> vocabCache = new AbstractCache.Builder<T>().build();


    List<INDArray> rows = new ArrayList<>();

    while ((line = reader.readLine()) != null) {
        if (line.isEmpty()) // skip empty line
            continue;
        ElementPair pair = ElementPair.fromEncodedJson(line);
        T element = factory.deserialize(pair.getObject());
        rows.add(Nd4j.create(pair.getVector()));
        vocabCache.addToken(element);
        vocabCache.addWordToIndex(element.getIndex(), element.getLabel());
    }

    reader.close();

    InMemoryLookupTable<T> lookupTable = (InMemoryLookupTable<T>) new InMemoryLookupTable.Builder<T>()
            .vectorLength(rows.get(0).columns()).cache(vocabCache).build();

    // stack all element vectors into a single syn0 weight matrix
    INDArray syn0 = Nd4j.vstack(rows);

    lookupTable.setSyn0(syn0);

    SequenceVectors<T> vectors = new SequenceVectors.Builder<T>(configuration).vocabCache(vocabCache)
            .lookupTable(lookupTable).resetModel(false).build();

    return vectors;
}
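A caller sketch for readSequenceVectors() above, assuming the stock VocabWordFactory as the element factory and a hypothetical file path; the surrounding method is assumed to declare throws IOException:

SequenceElementFactory<VocabWord> factory = new VocabWordFactory();
try (InputStream is = new FileInputStream("/path/to/sequencevectors.txt")) {
    SequenceVectors<VocabWord> restored = WordVectorSerializer.readSequenceVectors(factory, is);
    // syn0 is already set on the restored lookup table, so similarity queries work immediately
    Collection<String> nearest = restored.wordsNearest("word_0", 5); // hypothetical word
}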
 
Example 5
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0
@Test
public void sequenceVectorsCorrect_WhenDeserialized() {

    INDArray syn0 = Nd4j.rand(DataType.FLOAT, 10, 2),
            syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
            syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);

    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
            .Builder<VocabWord>()
            .useAdaGrad(false)
            .cache(cache)
            .build();

    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);
    lookupTable.setSyn1Neg(syn1Neg);

    SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration()).
            vocabCache(cache).
            lookupTable(lookupTable).
            build();
    SequenceVectors<VocabWord> deser = null;
    try {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        WordVectorSerializer.writeSequenceVectors(vectors, baos);
        byte[] bytesResult = baos.toByteArray();
        deser = WordVectorSerializer.readSequenceVectors(new ByteArrayInputStream(bytesResult), true);
    } catch (Exception e) {
        log.error("",e);
        fail();
    }

    assertNotNull(vectors.getConfiguration());
    assertEquals(vectors.getConfiguration(), deser.getConfiguration());

    assertEquals(cache.totalWordOccurrences(),deser.vocab().totalWordOccurrences());
    assertEquals(cache.totalNumberOfDocs(), deser.vocab().totalNumberOfDocs());
    assertEquals(cache.numWords(), deser.vocab().numWords());

    for (int i = 0; i < cache.words().size(); ++i) {
        val cached = cache.wordAtIndex(i);
        val restored = deser.vocab().wordAtIndex(i);
        assertNotNull(cached);
        assertEquals(cached, restored);
    }

}
 
Example 6
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0
@Test
public void W2V_Correct_WhenDeserialized() {

    INDArray syn0 = Nd4j.rand(DataType.FLOAT, 10, 2),
            syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
            syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);

    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
            .Builder<VocabWord>()
            .useAdaGrad(false)
            .cache(cache)
            .build();

    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);
    lookupTable.setSyn1Neg(syn1Neg);

    SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration()).
            vocabCache(cache).
            lookupTable(lookupTable).
            layerSize(200).
            modelUtils(new BasicModelUtils<VocabWord>()).
            build();

    Word2Vec word2Vec = new Word2Vec.Builder(vectors.getConfiguration())
            .vocabCache(vectors.vocab())
            .lookupTable(lookupTable)
            .modelUtils(new FlatModelUtils<VocabWord>())
            .limitVocabularySize(1000)
            .elementsLearningAlgorithm(CBOW.class.getCanonicalName())
            .allowParallelTokenization(true)
            .usePreciseMode(true)
            .batchSize(1024)
            .windowSize(23)
            .minWordFrequency(24)
            .iterations(54)
            .seed(45)
            .learningRate(0.08)
            .epochs(45)
            .stopWords(Collections.singletonList("NOT"))
            .sampling(44)
            .workers(45)
            .negativeSample(56)
            .useAdaGrad(true)
            .useHierarchicSoftmax(false)
            .minLearningRate(0.002)
            .resetModel(true)
            .useUnknown(true)
            .enableScavenger(true)
            .usePreciseWeightInit(true)
            .build();

    Word2Vec deser = null;
    try {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        WordVectorSerializer.writeWord2Vec(word2Vec, baos);
        byte[] bytesResult = baos.toByteArray();
        deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true);
    } catch (Exception e) {
        log.error("",e);
        fail();
    }

    assertNotNull(word2Vec.getConfiguration());
    assertEquals(word2Vec.getConfiguration(), deser.getConfiguration());

    assertEquals(cache.totalWordOccurrences(),deser.vocab().totalWordOccurrences());
    assertEquals(cache.totalNumberOfDocs(), deser.vocab().totalNumberOfDocs());
    assertEquals(cache.numWords(), deser.vocab().numWords());

    for (int i = 0; i < cache.words().size(); ++i) {
        val cached = cache.wordAtIndex(i);
        val restored = deser.vocab().wordAtIndex(i);
        assertNotNull(cached);
        assertEquals(cached, restored);
    }

}
 
Example 7
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0
@Test
public void ParaVec_Correct_WhenDeserialized() {

    INDArray syn0 = Nd4j.rand(DataType.FLOAT, 10, 2),
            syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
            syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);

    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
            .Builder<VocabWord>()
            .useAdaGrad(false)
            .cache(cache)
            .build();

    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);
    lookupTable.setSyn1Neg(syn1Neg);

    ParagraphVectors paragraphVectors = new ParagraphVectors.Builder()
            .vocabCache(cache)
            .lookupTable(lookupTable)
            .build();

    Word2Vec deser = null;
    try {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        WordVectorSerializer.writeWord2Vec(paragraphVectors, baos);
        byte[] bytesResult = baos.toByteArray();
        deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true);
    } catch (Exception e) {
        log.error("",e);
        fail();
    }

    assertNotNull(paragraphVectors.getConfiguration());
    assertEquals(paragraphVectors.getConfiguration(), deser.getConfiguration());

    assertEquals(cache.totalWordOccurrences(),deser.vocab().totalWordOccurrences());
    assertEquals(cache.totalNumberOfDocs(), deser.vocab().totalNumberOfDocs());
    assertEquals(cache.numWords(), deser.vocab().numWords());

    for (int i = 0; i < cache.words().size(); ++i) {
        val cached = cache.wordAtIndex(i);
        val restored = deser.vocab().wordAtIndex(i);
        assertNotNull(cached);
        assertEquals(cached, restored);
    }

}
 
Example 8
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0
@Test
public void weightLookupTable_Correct_WhenDeserialized() throws Exception {

    INDArray syn0 = Nd4j.rand(DataType.FLOAT, 10, 2),
            syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
            syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);

    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
            .Builder<VocabWord>()
            .useAdaGrad(false)
            .cache(cache)
            .build();

    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);
    lookupTable.setSyn1Neg(syn1Neg);

    File dir = testDir.newFolder();
    File file = new File(dir, "lookupTable.txt");

    WeightLookupTable<VocabWord> deser = null;
    try {
        WordVectorSerializer.writeLookupTable(lookupTable, file);
        deser = WordVectorSerializer.readLookupTable(file);
    } catch (Exception e) {
        log.error("",e);
        fail();
    }
    assertEquals(lookupTable.getVocab().totalWordOccurrences(), ((InMemoryLookupTable<VocabWord>)deser).getVocab().totalWordOccurrences());
    assertEquals(cache.totalNumberOfDocs(), ((InMemoryLookupTable<VocabWord>)deser).getVocab().totalNumberOfDocs());
    assertEquals(cache.numWords(), ((InMemoryLookupTable<VocabWord>)deser).getVocab().numWords());

    for (int i = 0; i < cache.words().size(); ++i) {
        val cached = cache.wordAtIndex(i);
        val restored = ((InMemoryLookupTable<VocabWord>)deser).getVocab().wordAtIndex(i);
        assertNotNull(cached);
        assertEquals(cached, restored);
    }

    assertEquals(lookupTable.getSyn0().columns(), ((InMemoryLookupTable<VocabWord>) deser).getSyn0().columns());
    assertEquals(lookupTable.getSyn0().rows(), ((InMemoryLookupTable<VocabWord>) deser).getSyn0().rows());
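    // element-wise comparison of the original and restored syn0 weights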
    for (int c = 0; c < ((InMemoryLookupTable<VocabWord>) deser).getSyn0().columns(); ++c) {
        for (int r = 0; r < ((InMemoryLookupTable<VocabWord>) deser).getSyn0().rows(); ++r) {
            assertEquals(lookupTable.getSyn0().getDouble(r,c),
                        ((InMemoryLookupTable<VocabWord>) deser).getSyn0().getDouble(r,c), 1e-5);
        }
    }
}