import os
import pathlib
import shutil
import tempfile
import unittest

import pfio

try:
    import chainer
    from chainer.training import extensions
    from chainer import testing
    chainer_available = True

    # They depend on Chainer
    from pfio.chainer_extensions.snapshot_writers import SimpleWriter
    from pfio.chainer_extensions import load_snapshot
except Exception:
    chainer_available = False


@unittest.skipIf(not chainer_available, "Chainer is not available")
def test_scan_directory():
    from pfio.chainer_extensions.snapshot import _scan_directory
    with tempfile.TemporaryDirectory() as td:
        files = ['tmpfoobar_10', 'foobar_10', 'foobar_123', 'tmpfoobar_10234']
        for file in files:
            pathlib.Path(os.path.join(td, file)).touch()

        latest = _scan_directory(pfio, td)
        assert latest is not None
        assert 'foobar_123' == latest


@unittest.skipIf(not chainer_available, "Chainer is not available")
def test_snapshot():
    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True

    with tempfile.TemporaryDirectory() as td:
        writer = SimpleWriter(td)
        snapshot = extensions.snapshot(writer=writer)
        snapshot(trainer)
        assert 'snapshot_iter_0' in os.listdir(td)

        trainer2 = chainer.testing.get_trainer_with_mock_updater()
        load_snapshot(trainer2, td, fail_on_no_file=True)


@unittest.skipIf(shutil.which('hdfs') is None, "HDFS client not installed")
@unittest.skipIf(not chainer_available, "Chainer is not available")
def test_snapshot_hdfs():
    trainer = chainer.testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True

    with pfio.create_handler('hdfs') as fs:
        tmpdir = "some-pfio-tmp-dir"
        fs.makedirs(tmpdir, exist_ok=True)
        file_list = list(fs.list(tmpdir))
        assert len(file_list) == 0

        writer = SimpleWriter(tmpdir, fs=fs)
        snapshot = extensions.snapshot(writer=writer)
        snapshot(trainer)

        assert 'snapshot_iter_0' in fs.list(tmpdir)

        trainer2 = chainer.testing.get_trainer_with_mock_updater()
        load_snapshot(trainer2, tmpdir, fs=fs, fail_on_no_file=True)

        # Cleanup
        fs.remove(tmpdir, recursive=True)