Java Code Examples for org.deeplearning4j.models.embeddings.loader.WordVectorSerializer#readWord2VecModel()

The following examples show how to use org.deeplearning4j.models.embeddings.loader.WordVectorSerializer#readWord2VecModel(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ParagraphVectorsTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code ParagraphVectors} instance on top of pre-trained word vectors read with
 * {@code WordVectorSerializer.readWord2VecModel}, infers vectors for two sentences and logs
 * their cosine similarity.
 *
 * <p>The model is read from a fixed absolute path, so the test is marked {@code @Ignore}.
 *
 * @throws Exception if the model cannot be read or inference fails
 */
@Ignore
@Test
public void testGoogleModelForInference() throws Exception {
    WordVectors googleVectors = WordVectorSerializer.readWord2VecModel(new File("/ext/GoogleNews-vectors-negative300.bin.gz"));

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    // Each builder option is set exactly once; word vectors are reused, not retrained.
    ParagraphVectors pv = new ParagraphVectors.Builder()
                    .tokenizerFactory(t)
                    .iterations(10)
                    .useHierarchicSoftmax(false)
                    .trainWordVectors(false)
                    .useExistingWordVectors(googleVectors)
                    .negativeSample(10)
                    .sequenceLearningAlgorithm(new DM<VocabWord>())
                    .build();

    INDArray vec1 = pv.inferVector("This text is pretty awesome");
    INDArray vec2 = pv.inferVector("Fantastic process of crazy things happening inside just for history purposes");

    log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2));
}
 
Example 2
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Reads the same archive via {@code readWord2Vec} and via {@code readWord2VecModel(file, false)}
 * and checks that both return the same vector for "night", and that the model loaded through
 * {@code readWord2VecModel} has neither syn1 nor syn1Neg populated.
 *
 * @throws Exception if the resource cannot be resolved or loaded
 */
@Test
public void testUnifiedLoaderArchive1() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    File archive = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();

    WordVectors reference = WordVectorSerializer.readWord2Vec(archive);
    WordVectors unified = WordVectorSerializer.readWord2VecModel(archive, false);

    INDArray expectedRow = reference.getWordVectorMatrix("night");
    INDArray actualRow = unified.getWordVectorMatrix("night");

    assertNotEquals(null, expectedRow);
    assertEquals(expectedRow, actualRow);

    // With the second argument set to false, no syn1 / syn1Neg weights are expected.
    InMemoryLookupTable unifiedTable = (InMemoryLookupTable) unified.lookupTable();
    assertEquals(null, unifiedTable.getSyn1());
    assertEquals(null, unifiedTable.getSyn1Neg());
}
 
Example 3
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Loads the binary model once through {@code readWord2VecModel} and once through
 * {@code loadStaticModel}, then verifies that both return the same vector for
 * "Morgan_Freeman".
 *
 * @throws Exception if either loader fails
 */
@Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testStaticLoaderBinary() throws Exception {

    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    final String probeWord = "Morgan_Freeman";

    WordVectors liveModel = WordVectorSerializer.readWord2VecModel(binaryFile);
    WordVectors staticModel = WordVectorSerializer.loadStaticModel(binaryFile);

    INDArray liveRow = liveModel.getWordVectorMatrix(probeWord);
    INDArray staticRow = staticModel.getWordVectorMatrix(probeWord);

    assertNotEquals(null, liveRow);
    assertEquals(liveRow, staticRow);
}
 
Example 4
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Reads the text model, rebuilds a {@code WordVectors} instance from its lookup table and
 * vocab via {@code fromTableAndVocab}, and checks the length and first component of two
 * known word vectors.
 *
 * @throws IOException if the model file cannot be read
 */
@Test
@Ignore
public void testFromTableAndVocab() throws IOException {

    WordVectors vec = WordVectorSerializer.readWord2VecModel(textFile);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();

    WordVectors wordVectors = WordVectorSerializer.fromTableAndVocab(lookupTable, lookupCache);
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");

    // Expected value goes first so that failure messages report expected/actual correctly.
    assertEquals(300, wordVector1.length);
    assertEquals(300, wordVector2.length);
    assertEquals(0.044423, wordVector1[0], 1e-3);
    assertEquals(0.051964, wordVector2[0], 1e-3);
}
 
Example 5
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Reads the binary model, writes its lookup table and vocab out with
 * {@code writeWordVectors}, reloads the written file with {@code loadTxtVectors}, and checks
 * the length and first component of two known word vectors.
 *
 * @throws IOException if reading or writing the vectors fails
 */
@Test
@Ignore
public void testWriteWordVectors() throws IOException {
    WordVectors vec = WordVectorSerializer.readWord2VecModel(binaryFile);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();
    WordVectorSerializer.writeWordVectors(lookupTable, lookupCache, pathToWriteto);

    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");

    // Expected value goes first so that failure messages report expected/actual correctly.
    assertEquals(300, wordVector1.length);
    assertEquals(300, wordVector2.length);
    assertEquals(0.044423, wordVector1[0], 1e-3);
    assertEquals(0.051964, wordVector2[0], 1e-3);
}
 
Example 6
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Reads the same archive via {@code readWord2Vec} and via {@code readWord2VecModel(file, true)}
 * and checks that both return the same vector for "night", and that the model loaded through
 * {@code readWord2VecModel} has syn1 populated.
 *
 * @throws Exception if the resource cannot be resolved or loaded
 */
@Test
public void testUnifiedLoaderArchive2() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    File archive = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();

    WordVectors reference = WordVectorSerializer.readWord2Vec(archive);
    WordVectors unified = WordVectorSerializer.readWord2VecModel(archive, true);

    INDArray expectedRow = reference.getWordVectorMatrix("night");
    INDArray actualRow = unified.getWordVectorMatrix("night");

    assertNotEquals(null, expectedRow);
    assertEquals(expectedRow, actualRow);

    // With the second argument set to true, syn1 weights are expected to be present.
    InMemoryLookupTable unifiedTable = (InMemoryLookupTable) unified.lookupTable();
    assertNotEquals(null, unifiedTable.getSyn1());
}
 
Example 7
Source File: Word2VecTests.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Loads a pre-trained model from a fixed local path, logs how long loading took, then
 * reconfigures it (tokenizer, sentence iterator, negative sampling, CBOW) and calls
 * {@code fit()} on it.
 *
 * @throws Exception if the model or the input file cannot be read, or fitting fails
 */
@Ignore
@Test
public void testWord2VecGoogleModelUptraining() throws Exception {
    long loadStarted = System.currentTimeMillis();
    File googleModel = new File("C:\\Users\\raver\\Downloads\\GoogleNews-vectors-negative300.bin.gz");
    Word2Vec vec = WordVectorSerializer.readWord2VecModel(googleModel, false);
    long loadFinished = System.currentTimeMillis();
    long elapsedMsec = loadFinished - loadStarted;
    log.info("Model loaded in {} msec", elapsedMsec);

    SentenceIterator sentences = new BasicLineIterator(inputFile.getAbsolutePath());

    // Tokenizer with the common pre-processor applied to every token.
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    vec.setTokenizerFactory(tokenizer);
    vec.setSentenceIterator(sentences);

    // Switch off hierarchical softmax and use negative sampling instead.
    vec.getConfiguration().setUseHierarchicSoftmax(false);
    vec.getConfiguration().setNegative(5.0);
    vec.setElementsLearningAlgorithm(new CBOW<VocabWord>());

    vec.fit();
}
 
Example 8
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Loads the text model through {@code loadTxtVectors} and through
 * {@code readWord2VecModel(file, true)} and verifies that both return the same vector for
 * "Morgan_Freeman".
 *
 * @throws Exception if either loader fails
 */
@Test
public void testUnifiedLoaderText() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    final String probeWord = "Morgan_Freeman";

    WordVectors reference = WordVectorSerializer.loadTxtVectors(textFile);
    WordVectors unified = WordVectorSerializer.readWord2VecModel(textFile, true);

    INDArray expectedRow = reference.getWordVectorMatrix(probeWord);
    INDArray actualRow = unified.getWordVectorMatrix(probeWord);

    assertNotEquals(null, expectedRow);
    assertEquals(expectedRow, actualRow);

    // The EXTENDED model is requested, but the file carries no syn1/huffman info,
    // so the loader should silently fall back to the simplified model.
    InMemoryLookupTable unifiedTable = (InMemoryLookupTable) unified.lookupTable();
    assertEquals(null, unifiedTable.getSyn1());
}
 
Example 9
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads the small sample model from its text and binary forms and checks that both report
 * (within 1e-1) the same similarity between "DBMS" and "DBMS's".
 *
 * @throws Exception if either sample file cannot be resolved or loaded
 */
@Test
@Ignore
public void testLoaderTextSmall() throws Exception {
    String w1 = "database";
    String w2 = "DBMS";
    WordVectors vecModel = WordVectorSerializer.readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.txt").getFile());
    WordVectors vectorsBinary = WordVectorSerializer.readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());

    // The results below are not asserted on; the calls are kept so the test still
    // exercises weight access and nearest-word lookup on both models.
    INDArray textWeights = vecModel.lookupTable().getWeights();
    INDArray binaryWeights = vectorsBinary.lookupTable().getWeights();
    Collection<String> nearest = vecModel.wordsNearest(w1, 10);
    Collection<String> nearestBinary = vectorsBinary.wordsNearest(w1, 10);
    System.out.println(nearestBinary);

    assertEquals(vecModel.similarity(w2, "DBMS's"), vectorsBinary.similarity(w2, "DBMS's"), 1e-1);

}
 
Example 10
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads the binary model through {@code readWord2VecModel(file)} and through
 * {@code readWord2VecModel(file, false)} and verifies that both return the same vector for
 * "Morgan_Freeman".
 *
 * @throws Exception if loading fails
 */
@Test
public void testUnifiedLoaderBinary() throws Exception {

    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    final String probeWord = "Morgan_Freeman";

    WordVectors defaultLoaded = WordVectorSerializer.readWord2VecModel(binaryFile);
    WordVectors explicitLoaded = WordVectorSerializer.readWord2VecModel(binaryFile, false);

    INDArray expectedRow = defaultLoaded.getWordVectorMatrix(probeWord);
    INDArray actualRow = explicitLoaded.getWordVectorMatrix(probeWord);

    assertNotEquals(null, expectedRow);
    assertEquals(expectedRow, actualRow);
}
 
Example 11
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Reads the binary model and checks its vocabulary size, the presence of two known tokens,
 * and the length and first component of their vectors.
 *
 * @throws IOException if the model file cannot be read
 */
@Test
public void testLoaderBinary() throws IOException {
    WordVectors vec = WordVectorSerializer.readWord2VecModel(binaryFile);

    // Expected value goes first so that failure messages report expected/actual correctly.
    assertEquals(30, vec.vocab().numWords());
    assertTrue(vec.vocab().hasToken("Morgan_Freeman"));
    assertTrue(vec.vocab().hasToken("JA_Montalbano"));

    double[] wordVector1 = vec.getWordVector("Morgan_Freeman");
    double[] wordVector2 = vec.getWordVector("JA_Montalbano");
    assertEquals(300, wordVector1.length);
    assertEquals(300, wordVector2.length);
    assertEquals(0.044423, wordVector1[0], 1e-3);
    assertEquals(0.051964, wordVector2[0], 1e-3);
}
 
Example 12
Source File: Word2VecLoader.java    From wekaDeeplearning4j with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Returns the structure for the configured source file, reading the word2vec model and
 * building the structure on the first call.
 *
 * @return the value of {@code m_structure} after it has been initialised if necessary
 * @throws IOException if no source file has been specified
 */
@Override
public Instances getStructure() throws IOException {
  if (m_sourceFile == null) {
    throw new IOException("No source has been specified.");
  }

  // Structure already built: nothing to load.
  if (m_structure != null) {
    return m_structure;
  }

  setSource(m_sourceFile);
  this.vec = WordVectorSerializer.readWord2VecModel(m_sourceFile);
  this.setStructure();
  return m_structure;
}
 
Example 13
Source File: WordVectorSerializerTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Loads two archived models (beta3 and beta4) with {@code readWord2VecModel(file, true)}
 * and checks that both serialise to the same JSON.
 *
 * @throws Exception if a model cannot be loaded or serialised; letting the exception
 *         propagate keeps the full stack trace in the test report
 */
@Test
public void testBackwardsCompatibleWord2Vec() throws Exception {
    File modelV3 = Resources.asFile("deeplearning4j-nlp/model_beta3.zip");
    File modelV4 = Resources.asFile("deeplearning4j-nlp/model_beta4.zip");
    Word2Vec word2Vec1 = WordVectorSerializer.readWord2VecModel(modelV3, true);
    Word2Vec word2Vec2 = WordVectorSerializer.readWord2VecModel(modelV4, true);

    assertEquals(word2Vec1.toJson(), word2Vec2.toJson());
}
 
Example 14
Source File: Word2VecLoader.java    From wekaDeeplearning4j with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Lazily builds and returns the structure for the configured source file.
 *
 * <p>When {@code m_structure} is still unset, the source is (re)applied, the word2vec model
 * is read from {@code m_sourceFile}, and {@code setStructure()} is invoked.
 *
 * @return the value of {@code m_structure}
 * @throws IOException if no source file has been specified
 */
@Override
public Instances getStructure() throws IOException {
  final boolean sourceMissing = (m_sourceFile == null);
  if (sourceMissing) {
    throw new IOException("No source has been specified.");
  }

  final boolean structurePending = (m_structure == null);
  if (structurePending) {
    setSource(m_sourceFile);
    this.vec = WordVectorSerializer.readWord2VecModel(m_sourceFile);
    this.setStructure();
  }

  return m_structure;
}
 
Example 15
Source File: GoogleNewsVectorExample.java    From Java-Deep-Learning-Cookbook with MIT License 5 votes vote down vote up
/**
 * Reads a word2vec model from a placeholder path and prints the ten words nearest to
 * "season". If loading raises {@code ND4JIllegalStateException}, a hint about the
 * placeholder path is printed instead.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    try {
        File file = new File("{PATH-TO-GOOGLE-WORD-VECTOR}");
        Word2Vec model = WordVectorSerializer.readWord2VecModel(file);
        // Print the returned collection directly; wrapping it in Arrays.asList would
        // produce a one-element list and print a nested "[[...]]".
        System.out.println(model.wordsNearest("season", 10));
    } catch (ND4JIllegalStateException e) {
        System.out.println("Please provide proper directory path in place of: PATH-TO-GOOGLE-WORD-VECTOR");
    }
}
 
Example 16
Source File: TestCnnSentenceDataSetIterator.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Feeds four labelled sentences, the first two of which contain no words expected to be in
 * the vocabulary, through a CNN2D {@code CnnSentenceDataSetIterator} and checks the features,
 * labels and feature mask of the returned data set against values built by hand from the two
 * remaining sentences. It then sanity-checks {@code loadSingleSentence}, including the
 * failure for a sentence with no known words.
 *
 * @throws Exception if the sample model cannot be loaded
 */
@Test
public void testCnnSentenceDataSetIteratorNoTokensEdgeCase() throws Exception {

    WordVectors w2v = WordVectorSerializer
                    .readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());

    int vectorSize = w2v.lookupTable().layerSize();

    List<String> sentences = new ArrayList<>();
    //First 2 sentences - no valid words
    sentences.add("NOVALID WORDSHERE");
    sentences.add("!!!");
    sentences.add("these balance Database model");
    sentences.add("into same THISWORDDOESNTEXIST are");
    int maxLength = 4;
    List<String> s1 = Arrays.asList("these", "balance", "Database", "model");
    List<String> s2 = Arrays.asList("into", "same", "are");

    List<String> labelsForSentences = Arrays.asList("Positive", "Negative", "Positive", "Negative");

    INDArray expLabels = Nd4j.create(new float[][] {{0, 1}, {1, 0}}); //Order of labels: alphabetic. Positive -> [0,1]


    LabeledSentenceProvider p = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
    CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN2D).sentenceProvider(p).wordVectors(w2v)
                    .useNormalizedWordVectors(true)
                    .maxSentenceLength(256).minibatchSize(32).sentencesAlongHeight(false).build();

    DataSet ds = dsi.next();

    INDArray expectedFeatures = Nd4j.create(DataType.FLOAT, 2, 1, vectorSize, maxLength);

    // Second row has only three tokens, so its last mask entry is 0.
    INDArray expectedFeatureMask = Nd4j.create(new float[][] {{1, 1, 1, 1}, {1, 1, 1, 0}}).reshape('c', 2, 1, 1, 4);

    for (int i = 0; i < s1.size(); i++) {
        expectedFeatures.get(NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(),
                        NDArrayIndex.point(i)).assign(w2v.getWordVectorMatrixNormalized(s1.get(i)));
    }

    for (int i = 0; i < s2.size(); i++) {
        expectedFeatures.get(NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(),
                        NDArrayIndex.point(i)).assign(w2v.getWordVectorMatrixNormalized(s2.get(i)));
    }

    assertArrayEquals(expectedFeatures.shape(), ds.getFeatures().shape());
    assertEquals(expectedFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
    assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
    assertNull(ds.getLabelsMaskArray());


    //Sanity check on single sentence loading:
    INDArray allKnownWords = dsi.loadSingleSentence("these balance");
    INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
    assertNotNull(allKnownWords);
    assertNotNull(withUnknown);

    try {
        dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
        fail("Expected exception");
    } catch (Exception e) {
        // Catch Exception rather than Throwable so the AssertionError raised by fail(...)
        // propagates as-is instead of being re-checked against the expected message.
        String m = e.getMessage();
        assertTrue(m, m != null && m.contains("RemoveWord") && m.contains("vocabulary"));
    }
}
 
Example 17
Source File: FastText.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Reads word vectors from the given file into {@code word2Vec}, marks the model vectors as
 * loaded and logs the source path.
 *
 * @param vectorsFile file to read the word vectors from
 */
public void loadPretrainedVectors(File vectorsFile) {
    word2Vec = WordVectorSerializer.readWord2VecModel(vectorsFile);
    modelVectorsLoaded = true;
    // SLF4J substitutes "{}" placeholders; a printf-style "%s" would be logged literally.
    log.info("Loaded vectorized representation from file {}. Functionality will be restricted.",
            vectorsFile.getAbsolutePath());
}
 
Example 18
Source File: Word2VecTestsSmall.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Loads the {@code vec.bin} classpath resource into {@code word2vec} before each test.
 *
 * @throws Exception if the resource cannot be resolved or the model cannot be read
 */
@Before
public void setUp() throws Exception {
    ClassPathResource modelResource = new ClassPathResource("vec.bin");
    word2vec = WordVectorSerializer.readWord2VecModel(modelResource.getFile());
}
 
Example 19
Source File: TestCnnSentenceDataSetIterator.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Runs four labelled sentences, the last two of which contain no words expected to be in
 * the vocabulary, through a CNN1D {@code CnnSentenceDataSetIterator} configured with
 * {@code UnknownWordHandling.UseUnknownVector}. Checks that all four sentences come back in a
 * single minibatch, that out-of-vocabulary positions carry the unknown-word vector, and that
 * positions past the end of a sentence match {@code unknown.like()}.
 *
 * @throws Exception if the sample model cannot be loaded
 */
@Test
public void testCnnSentenceDataSetIteratorUseUnknownVector() throws Exception {

    WordVectors w2v = WordVectorSerializer
            .readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());

    List<String> sentences = new ArrayList<>();
    sentences.add("these balance Database model");
    sentences.add("into same THISWORDDOESNTEXIST are");
    //Last 2 sentences - no valid words
    sentences.add("NOVALID WORDSHERE");
    sentences.add("!!!");

    List<String> labelsForSentences = Arrays.asList("Positive", "Negative", "Positive", "Negative");


    LabeledSentenceProvider p = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
    CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN1D)
            .unknownWordHandling(CnnSentenceDataSetIterator.UnknownWordHandling.UseUnknownVector)
            .sentenceProvider(p).wordVectors(w2v)
            .useNormalizedWordVectors(true)
            .maxSentenceLength(256).minibatchSize(4).sentencesAlongHeight(false).build();

    assertTrue(dsi.hasNext());
    DataSet ds = dsi.next();

    assertFalse(dsi.hasNext());

    INDArray f = ds.getFeatures();
    assertEquals(4, f.size(0));

    // Fall back to a freshly created array of matching size when the model has no vector for UNK.
    INDArray unknown = w2v.getWordVectorMatrix(w2v.getUNK());
    if (unknown == null) {
        unknown = Nd4j.create(DataType.FLOAT, f.size(1));
    }

    // Sample 2 ("NOVALID WORDSHERE"): both tokens unknown, later steps are padding.
    assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(0)));
    assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
    assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(3)));

    // Sample 3 ("!!!"): presumably a single unknown token, so step 1 should be padding -
    // TODO confirm against the tokenizer. Checking sample 2 / step 1 here instead would
    // contradict the assertion above unless the unknown vector equals unknown.like().
    assertEquals(unknown, f.get(NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.point(0)));
    assertEquals(unknown.like(), f.get(NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.point(1)));

    //Sanity check on single sentence loading:
    INDArray allKnownWords = dsi.loadSingleSentence("these balance");
    INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
    INDArray allUnknown = dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
    assertNotNull(allKnownWords);
    assertNotNull(withUnknown);
    assertNotNull(allUnknown);
}
 
Example 20
Source File: TestCnnSentenceDataSetIterator.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Edge case for minibatching: four sentences with minibatch size 2, where the last two
 * sentences contain no words expected to be in the vocabulary. The second minibatch would
 * therefore be empty, so {@code hasNext()} must return false once the first one has been
 * returned. The first minibatch is compared against features, labels and mask built by hand.
 *
 * @throws Exception if the sample model cannot be loaded
 */
@Test
public void testCnnSentenceDataSetIteratorNoValidTokensNextEdgeCase() throws Exception {
    WordVectors w2v = WordVectorSerializer
                    .readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());

    int vectorSize = w2v.lookupTable().layerSize();

    List<String> sentences = new ArrayList<>();
    sentences.add("these balance Database model");
    sentences.add("into same THISWORDDOESNTEXIST are");
    //Last 2 sentences - no valid words
    sentences.add("NOVALID WORDSHERE");
    sentences.add("!!!");
    int maxLength = 4;

    // Tokens of the two usable sentences that are expected to have vectors, in minibatch order.
    List<String> firstTokens = Arrays.asList("these", "balance", "Database", "model");
    List<String> secondTokens = Arrays.asList("into", "same", "are");
    List<List<String>> tokensPerExample = Arrays.asList(firstTokens, secondTokens);

    List<String> labelsForSentences = Arrays.asList("Positive", "Negative", "Positive", "Negative");

    //Order of labels: alphabetic. Positive -> [0,1]
    INDArray expLabels = Nd4j.create(new float[][] {{0, 1}, {1, 0}});

    LabeledSentenceProvider provider = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
    CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN2D)
                    .sentenceProvider(provider)
                    .wordVectors(w2v)
                    .useNormalizedWordVectors(true)
                    .maxSentenceLength(256)
                    .minibatchSize(2)
                    .sentencesAlongHeight(false)
                    .build();

    assertTrue(dsi.hasNext());
    DataSet ds = dsi.next();

    // The remaining two sentences have no usable tokens, so no further minibatch exists.
    assertFalse(dsi.hasNext());

    INDArray expectedFeatures = Nd4j.create(2, 1, vectorSize, maxLength);

    INDArray expectedFeatureMask = Nd4j.create(new float[][] {{1, 1, 1, 1}, {1, 1, 1, 0}}).reshape('c', 2, 1, 1, 4);

    // Fill example by example, token by token, with the normalized word vectors.
    for (int example = 0; example < tokensPerExample.size(); example++) {
        List<String> tokens = tokensPerExample.get(example);
        for (int step = 0; step < tokens.size(); step++) {
            expectedFeatures.get(NDArrayIndex.point(example), NDArrayIndex.point(0), NDArrayIndex.all(),
                            NDArrayIndex.point(step)).assign(w2v.getWordVectorMatrixNormalized(tokens.get(step)));
        }
    }

    assertArrayEquals(expectedFeatures.shape(), ds.getFeatures().shape());
    assertEquals(expectedFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
    assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
    assertNull(ds.getLabelsMaskArray());
}