"""Test the aboleth utilities."""

from types import GeneratorType

import numpy as np
import tensorflow as tf

import aboleth as ab


def test_batch():
    """Check the batch feed dict generator yields all data in one sweep."""
    X = np.arange(100)
    feed = {'X': X}

    batches = ab.batch(feed, batch_size=10, n_iter=10)

    # The batcher has to be a generator
    assert isinstance(batches, GeneratorType)

    # The first item pulled must be a dict holding a batch of the right size
    first = next(batches)
    assert isinstance(first, dict)
    assert 'X' in first
    assert len(first['X']) == 10

    # Drain the remaining batches, collecting every value that comes back so
    # we can confirm a single sweep returns all of X
    seen = []
    seen.extend(first['X'])
    for later in batches:
        chunk = later['X']
        assert len(chunk) == 10
        seen += list(chunk)

    assert len(seen) == len(X)
    assert set(X) == set(seen)


def test_batch_predict():
    """Test the batch prediction feed dict generator.

    Checks that ``ab.batch_prediction`` returns a generator of
    ``(ind, feed_dict)`` pairs, that every batch holds 10 items equal to
    ``X[ind]``, and that the batches taken together return all of ``X``.
    """
    X = np.arange(100)
    fd = {'X': X}

    data = ab.batch_prediction(fd, batch_size=10)

    # Make sure this is a generator
    assert isinstance(data, GeneratorType)

    # Make sure we get a dict back of a length we expect with correct indices
    accum = []
    for ind, d in data:
        assert isinstance(d, dict)
        assert 'X' in d
        assert len(d['X']) == 10
        assert all(X[ind] == d['X'])
        accum.extend(list(d['X']))

    # The per-batch checks above pass vacuously if the generator yields
    # nothing (or drops part of the data), so also test we get all of X back
    assert len(accum) == len(X)
    assert set(X) == set(accum)