"""
Test the data loading interface.
"""

import random
from pathlib import Path
from typing import Optional

from vergeml.data import Data
from vergeml.io import source, SourcePlugin, Sample
from vergeml.operation import operation, OperationPlugin
from vergeml.operations.augment import AugmentOperation

# pylint: disable=C0111

# -------------------------------------------------

def test_data_live_loader_meta(tmpdir):
    """Source meta is exposed with input caching disabled."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  cache_input=False)
    _test_data_meta(loader)

def test_data_mem_loader_meta(tmpdir):
    """Source meta is exposed with ``cache_input='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  cache_input='mem')
    _test_data_meta(loader)

def test_data_disk_loader_meta(tmpdir):
    """Source meta is exposed with ``cache_input='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  cache_input='disk')
    _test_data_meta(loader)

# -------------------------------------------------

def test_data_live_loader_with_ops_meta(tmpdir):
    """Source meta is exposed with an op and no input/output caching."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output=False)
    _test_data_meta(loader)

def test_data_mem_out_loader_with_ops_meta(tmpdir):
    """Source meta is exposed with an op and ``cache_output='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output='mem')
    _test_data_meta(loader)

def test_data_disk_out_loader_with_ops_meta(tmpdir):
    """Source meta is exposed with an op and ``cache_output='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output='disk')
    _test_data_meta(loader)

def test_data_mem_loader_with_ops_meta(tmpdir):
    """Source meta is exposed with an op and ``cache_input='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input='mem', cache_output=False)
    _test_data_meta(loader)

def test_data_disk_loader_with_ops_meta(tmpdir):
    """Source meta is exposed with an op and ``cache_input='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input='disk', cache_output=False)
    _test_data_meta(loader)

# -------------------------------------------------

def test_data_live_loader_num_samples(tmpdir):
    """Split sizes (no explicit split options) are correct without input caching."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  cache_input=False)
    _test_data_num_samples(loader)

def test_data_mem_loader_num_samples(tmpdir):
    """Split sizes (no explicit split options) are correct with ``cache_input='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  cache_input='mem')
    _test_data_num_samples(loader)

def test_data_disk_loader_num_samples(tmpdir):
    """Split sizes (no explicit split options) are correct with ``cache_input='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  cache_input='disk')
    _test_data_num_samples(loader)

# -------------------------------------------------

def test_data_live_loader_ops_num_samples(tmpdir):
    """The append op leaves split sizes unchanged with no input/output caching."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output=False)
    _test_data_num_samples(loader)

def test_data_mem_out_loader_ops_num_samples(tmpdir):
    """The append op leaves split sizes unchanged with ``cache_output='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output='mem')
    _test_data_num_samples(loader)

def test_data_disk_out_loader_ops_num_samples(tmpdir):
    """The append op leaves split sizes unchanged with ``cache_output='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output='disk')
    _test_data_num_samples(loader)

def test_data_mem_loader_ops_num_samples(tmpdir):
    """The append op leaves split sizes unchanged with ``cache_input='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input='mem', cache_output=False)
    _test_data_num_samples(loader)

def test_data_disk_loader_ops_num_samples(tmpdir):
    """The append op leaves split sizes unchanged with ``cache_input='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir)}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input='disk', cache_output=False)
    _test_data_num_samples(loader)

# -------------------------------------------------

def test_data_mem_loader_read_samples(tmpdir):
    """Samples come back transformed and in split order with ``cache_input='mem'``."""
    # TODO(review): an unexplained "TEST THIS!" marker stood here; clarify what
    # additional coverage was intended for this case, or drop this note.
    cache_dir = _prepare_dir(tmpdir)
    src = SourceTest({'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2})
    data = Data(input=src, cache_dir=cache_dir, cache_input='mem', cache_output=False)
    _test_data_read_samples(data)

def test_data_live_loader_read_samples(tmpdir):
    """Samples come back transformed and in split order without any caching."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  cache_input=False, cache_output=False)
    _test_data_read_samples(loader)

def test_data_disk_loader_read_samples(tmpdir):
    """Samples come back transformed and in split order with ``cache_input='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  cache_input='disk', cache_output=False)
    _test_data_read_samples(loader)

# --------------------------------------------------

def test_data_live_loader_with_ops(tmpdir):
    """Op output is fed through the source transform, with no caching."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output=False)
    _test_data_read_samples_transformed(loader)

def test_data_mem_loader_with_ops(tmpdir):
    """Op output is fed through the source transform, with ``cache_input='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input='mem', cache_output=False)
    _test_data_read_samples_transformed(loader)

def test_data_mem_out_loader_with_ops(tmpdir):
    """Op output is fed through the source transform, with ``cache_output='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output='mem')
    _test_data_read_samples_transformed(loader)

def test_data_disk_loader_with_ops(tmpdir):
    """Op output is fed through the source transform, with ``cache_input='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input='disk', cache_output=False)
    _test_data_read_samples_transformed(loader)

def test_data_disk_out_loader_with_ops(tmpdir):
    """Op output is fed through the source transform, with ``cache_output='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AppendStringOperation()],
                  cache_input=False, cache_output='disk')
    _test_data_read_samples_transformed(loader)

# --------------------------------------------------


def test_data_live_loader_with_multiplier_ops(tmpdir):
    """``AugmentOperation(variants=2)`` yields every sample twice, with no caching."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input=False, cache_output=False)
    _test_data_read_samples_transformed_x2(loader)


def test_data_mem_out_loader_with_multiplier_ops(tmpdir):
    """``AugmentOperation(variants=2)`` yields every sample twice, with ``cache_output='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input=False, cache_output='mem')
    _test_data_read_samples_transformed_x2(loader)


def test_data_disk_out_loader_with_multiplier_ops(tmpdir):
    """``AugmentOperation(variants=2)`` yields every sample twice, with ``cache_output='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input=False, cache_output='disk')
    _test_data_read_samples_transformed_x2(loader)


def test_data_mem_loader_with_multiplier_ops(tmpdir):
    """``AugmentOperation(variants=2)`` yields every sample twice, with ``cache_input='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input='mem', cache_output=False)
    _test_data_read_samples_transformed_x2(loader)


def test_data_disk_loader_with_multiplier_ops(tmpdir):
    """``AugmentOperation(variants=2)`` yields every sample twice, with ``cache_input='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input='disk', cache_output=False)
    _test_data_read_samples_transformed_x2(loader)

# --------------------------------------------------

def test_data_live_loader_with_multiplier_ops_between(tmpdir):
    """Slicing across duplicated samples works with no caching."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input=False, cache_output=False)
    _test_data_read_samples_x2_between(loader)

def test_data_mem_out_loader_with_multiplier_ops_between(tmpdir):
    """Slicing across duplicated samples works with ``cache_output='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input=False, cache_output='mem')
    _test_data_read_samples_x2_between(loader)

def test_data_disk_out_loader_with_multiplier_ops_between(tmpdir):
    """Slicing across duplicated samples works with ``cache_output='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input=False, cache_output='disk')
    _test_data_read_samples_x2_between(loader)

def test_data_mem_loader_with_multiplier_ops_between(tmpdir):
    """Slicing across duplicated samples works with ``cache_input='mem'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input='mem', cache_output=False)
    _test_data_read_samples_x2_between(loader)

def test_data_disk_loader_with_multiplier_ops_between(tmpdir):
    """Slicing across duplicated samples works with ``cache_input='disk'``."""
    cache_path = _prepare_dir(tmpdir)
    source_args = {'samples-dir': str(tmpdir), 'test-split': 2, 'val-split': 2}
    loader = Data(input=SourceTest(source_args), cache_dir=cache_path,
                  ops=[AugmentOperation(variants=2)],
                  cache_input='disk', cache_output=False)
    _test_data_read_samples_x2_between(loader)

# ---------------------------------------------------------------------------------

def _test_data_meta(data):
    assert data.meta['some-meta'] == 'meta-value'

def _test_data_num_samples(data):
    assert data.num_samples('train') == 8
    assert data.num_samples('test') == 1
    assert data.num_samples('val') == 1

def _test_data_read_samples(data):
    train_samples = list(data.load('train'))
    assert train_samples == [
        ('content8-transformed', None), ('content2-transformed', None),
        ('content9-transformed', None), ('content3-transformed', None),
        ('content5-transformed', None), ('content7-transformed', None)]
    assert len(train_samples) == 6

    val_samples = list(data.load('val'))#
    assert val_samples == [('content0-transformed', None), ('content1-transformed', None)]
    assert len(val_samples) == 2

    test_samples = list(data.load('test'))#
    assert test_samples == [('content4-transformed', None), ('content6-transformed', None)]
    assert len(test_samples) == 2

def _test_data_read_samples_transformed(data):
    train_samples = list(data.load('train'))
    assert train_samples == [
        ('content8-hello-transformed', None), ('content2-hello-transformed', None),
        ('content9-hello-transformed', None), ('content3-hello-transformed', None),
        ('content5-hello-transformed', None), ('content7-hello-transformed', None)]
    assert len(train_samples) == 6

    val_samples = list(data.load('val'))#
    assert val_samples == [('content0-hello-transformed', None),
                           ('content1-hello-transformed', None)]
    assert len(val_samples) == 2

    test_samples = list(data.load('test'))#
    assert test_samples == [('content4-hello-transformed', None),
                            ('content6-hello-transformed', None)]
    assert len(test_samples) == 2

def _test_data_read_samples_transformed_x2(data):
    train_samples = list(data.load('train'))
    assert train_samples == [
        ('content8-transformed', None), ('content8-transformed', None),
        ('content2-transformed', None), ('content2-transformed', None),
        ('content9-transformed', None), ('content9-transformed', None),
        ('content3-transformed', None), ('content3-transformed', None),
        ('content5-transformed', None), ('content5-transformed', None),
        ('content7-transformed', None), ('content7-transformed', None)]
    assert len(train_samples) == 12

    val_samples = list(data.load('val'))#
    assert val_samples == [('content0-transformed', None), ('content0-transformed', None),
                           ('content1-transformed', None), ('content1-transformed', None)]
    assert len(val_samples) == 4

    test_samples = list(data.load('test'))#
    assert test_samples == [('content4-transformed', None), ('content4-transformed', None),
                            ('content6-transformed', None), ('content6-transformed', None)]
    assert len(test_samples) == 4


def _test_data_read_samples_x2_between(data):
    train_samples = data.load('train')[1:7]
    assert train_samples == \
        [('content8-transformed', None),
         ('content2-transformed', None), ('content2-transformed', None),
         ('content9-transformed', None), ('content9-transformed', None),
         ('content3-transformed', None)]


def _prepare_dir(tmpdir):
    for i in range(0, 10):
        path = tmpdir.join(f"file{i}.test")
        path.write("content" + str(i))
    cache_dir = tmpdir.mkdir('.cache')
    return str(cache_dir)

@source('test-source', 'A test source.', input_patterns="**/*.test") # pylint: disable=W0223
class SourceTest(SourcePlugin):
    """Source plugin over the ``*.test`` fixture files written by these tests.

    Each file's text becomes a sample's ``x``; :meth:`transform` appends
    ``-transformed`` and sets ``y`` to ``None``.
    """

    def __init__(self, args: Optional[dict] = None):
        # Populated lazily by begin_read_samples; None means "not scanned yet".
        self.files = None
        # Copy so the caller's dict is not shared with the plugin instance.
        super().__init__((args or {}).copy())

    def begin_read_samples(self):
        """Publish ``some-meta`` and scan/split the input files, once.

        Later calls return immediately once ``self.files`` has been set.
        """
        if self.files is not None:
            return

        self.meta['some-meta'] = 'meta-value'
        self.files = self.scan_and_split_files()

    def num_samples(self, split: str) -> int:
        """Return how many files belong to *split*.

        Requires :meth:`begin_read_samples` to have populated ``self.files``.
        """
        return len(self.files[split])

    def read_samples(self, split, index, n=1):
        """Read up to *n* samples of *split*, starting at position *index*.

        Each sample gets its own ``random.Random`` seeded from the plugin's
        ``random_seed`` plus ``meta['filename']``, so for a fixed seed a given
        file always receives the same random stream.
        """
        samples = []
        for filename, meta in self.files[split][index:index + n]:
            content = Path(filename).read_text()
            rng = random.Random(str(self.random_seed) + meta['filename'])
            samples.append(Sample(content, None, meta.copy(), rng))
        return samples

    def transform(self, sample):
        """Append ``-transformed`` to ``sample.x``, clear ``sample.y`` and return it."""
        sample.x = sample.x + '-transformed'
        sample.y = None
        return sample

    def hash(self, state: str) -> str:
        """Extend the base hash with a hash of the scanned file listing."""
        return super().hash(state + self.hash_files(self.files))


@operation('append')
class AppendStringOperation(OperationPlugin):
    """Test operation whose transform appends the suffix ``-hello``."""

    type = str

    def transform(self, data, rng):
        """Return *data* followed by ``-hello``; *rng* is not used."""
        suffix = "-hello"
        return data + suffix