""" Tests views. """ import random import itertools from vergeml.views import IteratorView from vergeml.loader import LiveLoader from vergeml.io import SourcePlugin, source, Sample # pylint: disable=C0111 def test_iterview_default(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'train') assert list(map(lambda tp: tp[0], iterview)) == list(range(100)) def test_iterview_infinite(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'train', infinite=True) assert list(map(lambda tp: tp[0], itertools.islice(iterview, 150))) \ == list(range(100)) + list(range(50)) def test_iterview_random(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'train', randomize=True, fetch_size=1) assert list(map(lambda tp: tp[0], itertools.islice(iterview, 10))) \ == [92, 1, 43, 61, 35, 73, 48, 18, 98, 36] def test_iterview_random_fetch_size(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'train', randomize=True, fetch_size=10) assert list(map(lambda tp: tp[0], itertools.islice(iterview, 10))) \ == list(range(70, 80)) def test_iterview_transform(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'train', transform_x=lambda x: x + 10) assert list(map(lambda tp: tp[0], iterview)) == list(range(10, 110)) def test_iterview_meta(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'train', with_meta=True) assert next(iterview) == (0, 5, dict(meta=0)) assert next(iterview) == (1, 6, dict(meta=1)) assert next(iterview) == (2, 7, dict(meta=2)) def test_iterview_transform_y(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'train', transform_y=lambda _: 'transformed_y') assert next(iterview)[1] == 'transformed_y' def test_iterview_val(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'val') assert list(map(lambda tp: tp[0], iterview)) == list(range(10)) def test_iterview_test(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'test') assert list(map(lambda tp: tp[0], iterview)) == list(range(20)) def test_iterview_random2(): loader = LiveLoader('.cache', SourceTest()) iterview = IteratorView(loader, 'train', randomize=True, fetch_size=1) iterview2 = IteratorView(loader, 'train', randomize=True, random_seed=2601, fetch_size=1) assert list(map(lambda tp: tp[0], itertools.islice(iterview2, 10))) \ != list(map(lambda tp: tp[0], itertools.islice(iterview, 10))) @source('test-source', 'A test source.') # pylint: disable=W0223 class SourceTest(SourcePlugin): def __init__(self, args=None): self.data = dict( train=list(range(100)), val=list(range(10)), test=list(range(20)) ) super().__init__(args or {}) def num_samples(self, split: str) -> int: return len(self.data[split]) def read_samples(self, split, index, n=1): items = self.data[split][index: index+n] return [Sample(item, item+5, {'meta': item}, random.Random(self.random_seed + item)) for item in items]