from os.path import join

import pytest
import numpy as np

from pliers.extractors import BrightnessExtractor, SharpnessExtractor
from pliers.graph import Graph
from pliers.utils.scikit import PliersTransformer
from pliers.stimuli import ImageStim

from .utils import get_test_data_path


def test_graph_scikit():
    """Run a JSON-specified Graph through PliersTransformer on two images."""
    pytest.importorskip('pytesseract')
    pytest.importorskip('sklearn')
    image_dir = join(get_test_data_path(), 'image')
    stim1 = join(image_dir, 'apple.jpg')
    stim2 = join(image_dir, 'button.jpg')
    graph_spec = join(get_test_data_path(), 'graph', 'simple_graph.json')
    graph = Graph(spec=graph_spec)
    trans = PliersTransformer(graph)
    res = trans.fit_transform([stim1, stim2])
    # Two input images -> expect a 2x1 feature matrix.
    assert res.shape == (2, 1)
    # NOTE(review): the expected value may appear in either row; presumably
    # the row order of the graph output is not guaranteed — confirm.
    assert res[0][0] == 4 or res[1][0] == 4
    meta = trans.metadata_
    assert 'history' in meta.columns
    assert meta['history'][1] == 'ImageStim->TesseractConverter/TextStim'


def test_extractor_scikit():
    """PliersTransformer accepts an extractor instance or its class name."""
    pytest.importorskip('sklearn')
    image_dir = join(get_test_data_path(), 'image')
    stim = ImageStim(join(image_dir, 'apple.jpg'))
    # The same expectations must hold whether the transformer is built from
    # an extractor object or from the extractor's name as a string, so run
    # one shared set of assertions for both specs instead of duplicating it.
    for spec in (BrightnessExtractor(), 'BrightnessExtractor'):
        trans = PliersTransformer(spec)
        res = trans.fit_transform(stim)
        assert res.shape == (1, 1)
        assert np.isclose(res[0][0], 0.88784294, rtol=1e-5)
        meta = trans.metadata_
        assert np.isnan(meta['onset'][0])


def test_within_pipeline():
    """PliersTransformer works as a step inside an sklearn Pipeline."""
    # NOTE(review): cv2 is presumably required by SharpnessExtractor — confirm.
    pytest.importorskip('cv2')
    pytest.importorskip('sklearn')
    # Imported only after the skip check so the module still collects
    # when sklearn is not installed.
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import Normalizer

    stim = join(get_test_data_path(), 'image', 'apple.jpg')
    graph = Graph([BrightnessExtractor(), SharpnessExtractor()])
    trans = PliersTransformer(graph)
    normalizer = Normalizer()
    pipeline = Pipeline([('pliers', trans), ('normalizer', normalizer)])
    res = pipeline.fit_transform(stim)
    # One image, two extractors in the graph -> a 1x2 matrix.
    assert res.shape == (1, 2)
    # The expected pair has unit L2 norm (0.66393**2 + 0.74780**2 ~= 1),
    # consistent with Normalizer's default row normalization.
    assert np.isclose(res[0][0], 0.66393, rtol=1e-5)
    assert np.isclose(res[0][1], 0.74780, rtol=1e-5)
    meta = trans.metadata_
    assert 'onset' in meta.columns
    assert meta['class'][0] == 'ImageStim'