import os import pytest from gensim.models import Word2Vec, Doc2Vec from gensim.models.doc2vec import TaggedDocument from gensim.utils import simple_preprocess from vec4ir import Doc2VecInference, Retrieval, Matching, Tfidf, WordCentroidDistance, build_analyzer, WordMoversDistance documents = ["The quick brown fox jumps over the lazy dog", "Computer scientists are lazy lazy lazy"] def test_build_analyzer(): analyzer =build_analyzer('sklearn', False, lowercase=True) analyzed = analyzer("the quick brown fox") assert analyzed[-1] == "fox" DEFAULT_ANALYZER = build_analyzer('sklearn', stop_words=False, lowercase=True) TEST_FILE = "test.tmp" def test_matching(): match_op = Matching() match_op.fit(documents) matched = match_op.predict("fox") assert matched == [0] def test_tfidf(): # Test tfidf retrieval with auto-generated ids tfidf = Tfidf() tfidf.fit(documents) result = tfidf.query('lazy') assert result[0] == 1 assert result[1] == 0 def test_retrieval(): # Test retrieval with given ids tfidf = Tfidf() retrieval = Retrieval(tfidf) ids = ['fox_example', 'lazy_example'] retrieval.fit(documents, ids) result = retrieval.query('fox') assert result[0] == 'fox_example' assert result[1] == 'lazy_example' def test_word2vec(): model = Word2Vec([doc.split() for doc in documents], iter=1, min_count=1) match_op = Matching() with pytest.raises(ValueError): wcd = WordCentroidDistance(model) wcd = WordCentroidDistance(model.wv) retrieval = Retrieval(wcd, matching=match_op) retrieval.fit(documents) result = retrieval.query('dog') assert result[0] == 0 def test_combined(): model = Word2Vec([doc.split() for doc in documents], iter=1, min_count=1) wcd = WordCentroidDistance(model.wv) tfidf = Tfidf() wcd.fit(documents) # # they can operate on different feilds tfidf.fit(['fox', 'scientists']) match_op = Matching().fit(documents) combined = wcd + tfidf ** 2 retrieval = Retrieval(combined, matching=match_op, labels=[7,42]) result = retrieval.query('fox') assert result[0] == 7 result = retrieval.query('scientists') assert result[0] == 42 # # PYEMD is required # def test_wordmovers(): # model = Word2Vec([doc.split() for doc in documents], iter=1, min_count=1) # match_op = Matching() # wmd = WordMoversDistance(model.wv) # retrieval = Retrieval(wmd, matching=match_op) # retrieval.fit(documents) # result = retrieval.query('dog') # assert result[0] == 0 def test_doc2vec_inference(): tagged_docs = [TaggedDocument(simple_preprocess(doc), [i]) for i, doc in enumerate(documents)] model = Doc2Vec(tagged_docs, epochs=1, min_count=1) d2v = Doc2VecInference(model, DEFAULT_ANALYZER) match_op = Matching() retrieval = Retrieval(d2v, matching=match_op).fit(documents) result = retrieval.query("scientists") assert result[0] == 1 def test_doc2vec_inference_saveload(): tagged_docs = [TaggedDocument(simple_preprocess(doc), [i]) for i, doc in enumerate(documents)] model = Doc2Vec(tagged_docs, epochs=1, min_count=1, vector_size=10) model.save(TEST_FILE) del model model = Doc2Vec.load(TEST_FILE) os.remove(TEST_FILE) d2v = Doc2VecInference(model, DEFAULT_ANALYZER) match_op = Matching() retrieval = Retrieval(d2v, matching=match_op).fit(documents) result = retrieval.query("scientists") assert result[0] == 1