""""Configuration file for pytest, for testing LISC.""" import pytest import os import shutil import pkg_resources as pkg import nltk from lisc.objects import Counts, Words from lisc.requester import Requester from lisc.core.modutils import safe_import from lisc.utils.db import create_file_structure from lisc.tests.utils import create_files, load_base, load_arts, load_arts_all from lisc.tests.utils import TestDB as TDB plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### def pytest_configure(config): # Set backend for matplotlib tests, if mpl is available if plt: plt.switch_backend('agg') @pytest.fixture(scope='session', autouse=True) def download_data(): # Download required nltk data for tokenizing nltk.download('punkt') nltk.download('stopwords') @pytest.fixture(scope='session', autouse=True) def check_db(): """Once, prior to session, this will clear and re-initialize the test file database.""" # Create the test database directory tests_dir = pkg.resource_filename('lisc', 'tests') test_db_name = 'test_db' # If the directories already exist, clear them if os.path.exists(os.path.join(tests_dir, test_db_name)): shutil.rmtree(os.path.join(tests_dir, test_db_name)) tdb = create_file_structure(tests_dir, test_db_name) create_files(tdb) @pytest.fixture(scope='session') def tdb(): return TDB() @pytest.fixture(scope='session') def tcounts(): return Counts() @pytest.fixture(scope='session') def twords(): return Words() @pytest.fixture(scope='function') def treq(): return Requester() @pytest.fixture(scope='function') def tbase(): return load_base() @pytest.fixture(scope='function') def tbase_terms(): return load_base(True, True) @pytest.fixture(scope='function') def tarts_empty(): return load_arts() @pytest.fixture(scope='function') def tarts_full(): return load_arts(add_data=True, n_data=2) @pytest.fixture(scope='function') def tarts_all(): return load_arts_all()