"""Unit tests for ``redshells.train.TrainPairwiseSimilarityModel``."""
import unittest
from unittest.mock import MagicMock

import luigi
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from redshells.train import TrainPairwiseSimilarityModel


class _DummyTask(luigi.Task):
    """Placeholder upstream task.

    Instances are passed as the ``*_task`` parameters of the task under test.
    Their outputs are never read, because the test replaces ``load`` with an
    in-memory fake.
    """
    pass


class TrainPairwiseSimilarityModelTest(unittest.TestCase):
    """Runs ``TrainPairwiseSimilarityModel`` with ``load``/``dump`` replaced by in-memory fakes."""

    def setUp(self):
        # Inputs served by the fake ``load``, keyed by target name.
        self.input_data = dict()
        # Object captured by the fake ``dump``.
        self.dump_data = None
        # luigi caches task instances by parameter values. Clear the cache so each test
        # constructs a fresh task rather than reusing one (and its mocks) from a prior test.
        TrainPairwiseSimilarityModel.clear_instance_cache()

    def test_run(self):
        """A RandomForestClassifier is trained and dumped from embeddings plus labelled pairs."""
        self.input_data['item2embedding'] = dict(i0=[1, 2], i1=[3, 4])
        self.input_data['similarity_data'] = pd.DataFrame(
            dict(item1=['i0', 'i0', 'i1'], item2=['i0', 'i1', 'i1'], similarity=[1, 0, 1]))

        task = TrainPairwiseSimilarityModel(
            item2embedding_task=_DummyTask(),
            similarity_data_task=_DummyTask(),
            model_name='RandomForestClassifier',
            item0_column_name='item1',
            item1_column_name='item2',
            similarity_column_name='similarity')
        # Replace I/O on the task instance. MagicMock forwards the same call arguments
        # (without ``self``) to the side-effect functions.
        task.load = MagicMock(side_effect=self._load)
        task.dump = MagicMock(side_effect=self._dump)

        task.run()

        self.assertIsInstance(self.dump_data, RandomForestClassifier)

    def _load(self, *args, **kwargs):
        """Fake ``load``: look up a target name in ``self.input_data``.

        The target may be given as ``target=`` or as the first positional argument.
        With no target, the whole ``input_data`` dict is returned.
        Unknown targets yield ``None``.
        """
        if 'target' in kwargs:
            return self.input_data.get(kwargs['target'], None)
        if len(args) > 0:
            return self.input_data.get(args[0], None)
        return self.input_data

    def _dump(self, obj, *args, **kwargs):
        """Fake ``dump``: record the dumped object in ``self.dump_data``.

        ``obj`` may be passed positionally or by keyword. Any extra positional or
        keyword arguments (e.g. a target) are accepted and ignored, so a keyword-style
        call does not raise ``TypeError``.
        """
        self.dump_data = obj


if __name__ == '__main__':
    unittest.main()