#! /usr/bin/env python
# -*- coding: utf-8 -*-

import tempfile
import numpy as np
import os
import unittest

from tensorpack.dataflow import HDF5Serializer, LMDBSerializer, NumpySerializer, TFRecordSerializer
from tensorpack.dataflow.base import DataFlow


def delete_file_if_exists(fn):
    """Remove the file at ``fn``, ignoring any ``OSError`` raised while doing so.

    Args:
        fn: path of the file to remove.
    """
    try:
        os.unlink(fn)
    except OSError:
        # Best-effort cleanup: a path that cannot be removed (e.g. because it
        # does not exist) is silently left alone.
        return


class SeededFakeDataFlow(DataFlow):
    """In-memory DataFlow yielding reproducible fake ``[label, image]`` datapoints.

    The datapoints are generated in :meth:`reset_state` after seeding NumPy's
    global RNG with ``seed``, so instances created with the same ``seed`` and
    ``size`` produce the same sequence.

    Args:
        seed (int): value passed to ``np.random.seed`` before generating data.
        size (int): number of datapoints to generate.
    """

    def __init__(self, seed=42, size=32):
        super(SeededFakeDataFlow, self).__init__()
        self.seed = seed
        self._size = size
        # Filled by reset_state(); empty until then.
        self.cache = []

    def reset_state(self):
        """Regenerate the cached datapoints from the configured seed.

        Each datapoint is ``[label, img]`` where ``label`` is drawn with
        ``np.random.randint(low=0, high=10)`` and ``img`` with
        ``np.random.randn(28, 28, 3)``.
        """
        # NOTE: this reseeds NumPy's global RNG, not a private generator.
        np.random.seed(self.seed)
        # Build a fresh list instead of appending to the old one, so calling
        # reset_state() more than once never leaves more than `size` datapoints
        # in the cache (which would disagree with __len__).
        cache = []
        for _ in range(self._size):
            label = np.random.randint(low=0, high=10)
            img = np.random.randn(28, 28, 3)
            cache.append([label, img])
        self.cache = cache

    def __len__(self):
        """Return the configured number of datapoints."""
        return self._size

    def __iter__(self):
        """Yield the cached datapoints in generation order."""
        for dp in self.cache:
            yield dp


class SerializerTest(unittest.TestCase):
    """Round-trip tests: save a fake DataFlow with a serializer, load it back, compare."""

    def run_write_read_test(self, file, serializer, w_args, w_kwargs, r_args, r_kwargs, error_msg):
        """Save a ``SeededFakeDataFlow`` to ``file`` and check that loading reproduces it.

        Args:
            file: path the serializer writes to; removed first if it exists.
            serializer: object providing ``save(df, path, ...)`` and ``load(path, ...)``.
            w_args: extra positional arguments for ``serializer.save``.
            w_kwargs: extra keyword arguments for ``serializer.save``.
            r_args: extra positional arguments for ``serializer.load``.
            r_kwargs: extra keyword arguments for ``serializer.load``.
            error_msg: skip reason used when an ``ImportError`` is raised.
        """
        try:
            delete_file_if_exists(file)

            ds_expected = SeededFakeDataFlow()
            serializer.save(ds_expected, file, *w_args, **w_kwargs)
            ds_actual = serializer.load(file, *r_args, **r_kwargs)

            ds_actual.reset_state()
            ds_expected.reset_state()

            # Materialize inside the try block in case the loader imports its
            # backend lazily. A comprehension only needs __iter__, so it does
            # not rely on the loaded dataflow implementing __len__.
            actual = [dp for dp in ds_actual]
        except ImportError:
            # Report a missing optional dependency as a skip instead of
            # printing a message and letting the test count as passed.
            self.skipTest(error_msg)

        # zip() stops at the shorter input, so without this check a loader
        # that yields too few (or zero) datapoints would pass vacuously.
        self.assertEqual(
            len(ds_expected), len(actual),
            'loaded dataflow yielded an unexpected number of datapoints')

        for dp_expected, dp_actual in zip(ds_expected, actual):
            self.assertEqual(dp_expected[0], dp_actual[0])
            self.assertTrue(np.allclose(dp_expected[1], dp_actual[1]))

    def test_lmdb(self):
        """Round-trip through ``LMDBSerializer`` with shuffling disabled."""
        with tempfile.TemporaryDirectory() as f:
            self.run_write_read_test(
                os.path.join(f, 'test.lmdb'),
                LMDBSerializer,
                [], {},
                [], {'shuffle': False},
                'Skip test_lmdb, no lmdb available')

    def test_tfrecord(self):
        """Round-trip through ``TFRecordSerializer``, passing the dataset size on load."""
        with tempfile.TemporaryDirectory() as f:
            self.run_write_read_test(
                os.path.join(f, 'test.tfrecord'),
                TFRecordSerializer,
                [], {},
                [], {'size': 32},
                'Skip test_tfrecord, no tensorflow available')

    def test_numpy(self):
        """Round-trip through ``NumpySerializer`` with shuffling disabled."""
        with tempfile.TemporaryDirectory() as f:
            self.run_write_read_test(
                os.path.join(f, 'test.npz'),
                NumpySerializer,
                [], {},
                [], {'shuffle': False},
                'Skip test_numpy, no numpy available')

    def test_hdf5(self):
        """Round-trip through ``HDF5Serializer`` using explicit component names."""
        # The same positional argument (names of the two datapoint components)
        # is passed to both save and load.
        args = [['label', 'image']]
        with tempfile.TemporaryDirectory() as f:
            self.run_write_read_test(
                os.path.join(f, 'test.h5'),
                HDF5Serializer,
                args, {},
                args, {'shuffle': False},
                'Skip test_hdf5, no h5py available')


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()