Java Code Examples for org.deeplearning4j.models.word2vec.VocabWord#markAsLabel()

The following examples show how to use org.deeplearning4j.models.word2vec.VocabWord#markAsLabel(). You can vote up the examples you like or vote down the ones you don't, and follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source file: WordVectorSerializer.java, from deeplearning4j (Apache License 2.0). Votes: 5.
/**
 * Restores a {@link ParagraphVectors} model previously saved with {@code writeParagraphVectors()}.
 *
 * <p>The configuration, vocabulary and lookup table are first loaded with {@code readWord2Vec(file)}
 * and wrapped in a new {@code ParagraphVectors} instance built with {@code resetModel(false)}.
 * If the archive also contains a {@code labels.txt} entry, each of its lines is trimmed, passed
 * through {@code ReadHelper.decodeB64}, and looked up in the vocabulary. Every matching element is
 * marked via {@code markAsLabel(true)}; lines with no matching element are ignored. Finally
 * {@code extractLabels()} is invoked on the restored model.
 *
 * @param file zip archive previously written by {@code writeParagraphVectors()}
 * @return the restored model, with label flags restored when {@code labels.txt} is present
 * @throws IOException if reading the archive (or the underlying Word2Vec data) fails
 */
public static ParagraphVectors readParagraphVectors(File file) throws IOException {
    Word2Vec w2v = readWord2Vec(file);

    // "Convert" the Word2Vec model into a ParagraphVectors model that shares its vocab and weights.
    // NOTE(review): resetModel(false) presumably prevents re-initialising the restored weights.
    ParagraphVectors vectors = new ParagraphVectors.Builder(w2v.getConfiguration())
            .vocabCache(w2v.getVocab())
            .lookupTable(w2v.getLookupTable())
            .resetModel(false)
            .build();

    try (ZipFile zipFile = new ZipFile(file)) {
        // labels.txt is optional; without it no element is marked as a label here.
        ZipEntry labels = zipFile.getEntry("labels.txt");
        if (labels != null) {
            // Declare the entry stream as a resource too, so it is closed deterministically
            // together with the reader that wraps it.
            try (InputStream stream = zipFile.getInputStream(labels);
                 BufferedReader reader =
                         new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    // One label per line; presumably Base64-encoded by the writer (decodeB64).
                    VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim()));
                    if (word != null) {
                        word.markAsLabel(true);
                    }
                }
            }
        }
    }

    vectors.extractLabels();

    return vectors;
}
 
Example 2
Source file: DocumentSequenceConvertFunction.java, from deeplearning4j (Apache License 2.0). Votes: 5.
/**
 * Converts a {@link LabelledDocument} into a {@link Sequence} of {@link VocabWord} elements.
 *
 * <p>When the document carries non-empty referenced content, those elements are added as-is.
 * Otherwise the document's text is tokenized, creating the tokenizer factory on first use, and
 * each non-null, non-empty token is added as {@code new VocabWord(1.0, token)}. Afterwards every
 * non-null, non-empty document label is attached as a sequence label that has been marked via
 * {@code markAsLabel(true)}.
 *
 * @param document the document to convert
 * @return the sequence built from the document's content and labels
 * @throws Exception as permitted by the overridden signature
 */
@Override
public Sequence<VocabWord> call(LabelledDocument document) throws Exception {
    Sequence<VocabWord> result = new Sequence<>();

    boolean hasReferencedContent = document.getReferencedContent() != null
            && !document.getReferencedContent().isEmpty();

    if (hasReferencedContent) {
        // Pre-built elements are available: no tokenization needed.
        result.addElements(document.getReferencedContent());
    } else {
        // Lazily create the tokenizer factory the first time raw text has to be split.
        if (tokenizerFactory == null) {
            instantiateTokenizerFactory();
        }

        for (String piece : tokenizerFactory.create(document.getContent()).getTokens()) {
            if (piece != null && !piece.isEmpty()) {
                result.addElement(new VocabWord(1.0, piece));
            }
        }
    }

    // Attach document labels, skipping null/empty entries.
    for (String labelText : document.getLabels()) {
        if (labelText != null && !labelText.isEmpty()) {
            VocabWord labelWord = new VocabWord(1.0, labelText);
            labelWord.markAsLabel(true);
            result.addSequenceLabel(labelWord);
        }
    }

    return result;
}
 
Example 3
Source file: WordVectorSerializerTest.java, from deeplearning4j (Apache License 2.0). Votes: 4.
/**
 * Round-trips a randomly generated ParagraphVectors model through
 * {@code WordVectorSerializer.writeParagraphVectors} / {@code readParagraphVectors}.
 *
 * <p>It checks that syn0/syn1 and, for every vocabulary element, the label flag, word, frequency,
 * points and codes survive serialization. Inputs are random (points, codes, and which words are
 * labels), so the exact model differs between runs.
 */
@Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testParaVecSerialization1() throws Exception {
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(14123);
    configuration.setLayersSize(156);

    INDArray syn0 = Nd4j.rand(100, configuration.getLayersSize());
    INDArray syn1 = Nd4j.rand(100, configuration.getLayersSize());

    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

    // Build a 100-word vocabulary with random points/codes; about 3 in 10 words become labels.
    for (int i = 0; i < 100; i++) {
        VocabWord word = new VocabWord((float) i, "word_" + i);
        List<Integer> points = new ArrayList<>();
        List<Byte> codes = new ArrayList<>();
        int num = RandomUtils.nextInt(1, 20);
        for (int x = 0; x < num; x++) {
            points.add(RandomUtils.nextInt(1, 100000));
            // Only one random byte is needed per code.
            codes.add(RandomUtils.nextBytes(1)[0]);
        }
        if (RandomUtils.nextInt(0, 10) < 3) {
            word.markAsLabel(true);
        }
        word.setIndex(i);
        word.setPoints(points);
        word.setCodes(codes);
        cache.addToken(word);
        cache.addWordToIndex(i, word.getLabel());
    }

    InMemoryLookupTable<VocabWord> lookupTable =
                    (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                    .vectorLength(configuration.getLayersSize()).cache(cache).build();

    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);

    ParagraphVectors originalVectors =
                    new ParagraphVectors.Builder(configuration).vocabCache(cache).lookupTable(lookupTable).build();

    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);

    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);

    InMemoryLookupTable<VocabWord> restoredLookupTable =
                    (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();

    // assertEquals takes (expected, actual): the original model is the expected value.
    assertEquals(lookupTable.getSyn0(), restoredLookupTable.getSyn0());
    assertEquals(lookupTable.getSyn1(), restoredLookupTable.getSyn1());

    for (int i = 0; i < cache.numWords(); i++) {
        VocabWord original = cache.elementAtIndex(i);
        VocabWord restored = restoredVocab.elementAtIndex(i);

        assertEquals(original.isLabel(), restored.isLabel());
        assertEquals(cache.wordAtIndex(i), restoredVocab.wordAtIndex(i));
        assertEquals(original.getElementFrequency(), restored.getElementFrequency(), 0.1f);
        // List.equals checks size and element-wise equality in order, like the old manual loops.
        assertEquals(original.getPoints(), restored.getPoints());
        assertEquals(original.getCodes(), restored.getCodes());
    }
}