from unittest import TestCase, main, mock
from schema import SchemaError
import copy
import argparse
import skelebot as sb

class TestRepository(TestCase):
    """Tests for the skelebot repository component with Artifactory and S3 backends.

    All remote access is mocked: ``artifactory.ArtifactoryPath`` is patched for
    the Artifactory tests and ``boto3.Session`` for the S3 tests, so these
    tests only verify which calls the component makes and which errors it
    raises.
    """

    # Repository fixtures; rebuilt for every test in setUp().
    artifactory = None
    s3 = None
    s3_subfolder = None

    # Valid config dicts used as the baseline for the validation tests.
    artifactoryDict = {
        "url": "test",
        "repo": "test",
        "path": "path"
    }

    s3Dict = {
        "bucket": "my-bucket",
        "region": "us-east-1",
        "profile": "test"
    }

    artifactDict = {
        "name": "test",
        "file": "test"
    }

    # Message expected when pushing an already-present version without force.
    ALREADY_PUSHED_MSG = "This artifact version already exists. Please bump the version or use the force parameter (-f) to overwrite the artifact."

    # Message expected when a LATEST pull finds no usable version.
    NO_COMPATIBLE_VERSION_MSG = "No Compatible Version Found"

    def setUp(self):
        """Build one Artifactory-backed and two S3-backed repositories.

        ``self.artifactory`` and ``self.s3`` hold two artifacts ("test" and
        "test2"); ``self.s3_subfolder`` holds only "test" and uses a bucket
        string that includes a sub-path ("my-bucket/sub/folder").
        """
        artifact = sb.components.repository.repository.Artifact("test", "test.pkl")
        artifact2 = sb.components.repository.repository.Artifact("test2", "test2.pkl")
        artifactoryRepo = sb.components.repository.artifactoryRepo.ArtifactoryRepo("artifactory.test.com", "ml", "test")
        s3Repo = sb.components.repository.s3Repo.S3Repo("my-bucket", "us-east-1", "test")
        s3Repo_path = sb.components.repository.s3Repo.S3Repo("my-bucket/sub/folder", "us-east-1", "test")

        self.artifactory = sb.components.repository.repository.Repository([artifact, artifact2], s3=None, artifactory=artifactoryRepo)
        self.s3 = sb.components.repository.repository.Repository([artifact, artifact2], s3=s3Repo, artifactory=None)
        self.s3_subfolder = sb.components.repository.repository.Repository([artifact], s3=s3Repo_path, artifactory=None)

    @staticmethod
    def _mock_s3_client(mock_boto3_session, keys=None):
        """Wire a patched ``boto3.Session`` so it hands back a mock client.

        Args:
            mock_boto3_session: the mock injected by ``mock.patch('boto3.Session')``.
            keys: optional iterable of object keys; when given,
                ``list_objects_v2`` returns them as ``{"Contents": [{"Key": ...}]}``.

        Returns:
            The mock client that ``Session(...).client(...)`` will return.
        """
        mock_client = mock.Mock()
        mock_session = mock.Mock()
        if keys is not None:
            mock_client.list_objects_v2.return_value = {"Contents": [{"Key": key} for key in keys]}
        mock_session.client.return_value = mock_client
        mock_boto3_session.return_value = mock_session
        return mock_client

    def _assert_type_error(self, validate, base, label, attr, reset, expected):
        """Assert that ``validate`` rejects ``base`` once ``attr`` is set to ``reset``.

        The SchemaError code must read "<label> '<attr>' must be a <expected>".
        The test fails if no SchemaError is raised at all.
        """
        candidate = copy.deepcopy(base)
        candidate[attr] = reset

        with self.assertRaises(SchemaError) as ctx:
            validate(candidate)

        self.assertEqual(
            ctx.exception.code,
            "{label} '{attr}' must be a {expected}".format(label=label, attr=attr, expected=expected)
        )

    def test_repository_load(self):
        """Repository.load accepts an s3 or artifactory config and rejects neither."""
        sb.components.repository.repository.Repository.load({"artifacts": [], "s3": self.s3Dict})
        sb.components.repository.repository.Repository.load({"artifacts": [], "artifactory": self.artifactoryDict})

        with self.assertRaises(SchemaError) as ctx:
            sb.components.repository.repository.Repository.load({"artifacts": []})

        self.assertEqual(str(ctx.exception), "Repository must contain 's3' or 'artifactory' config")

    def test_addParsers_artifactory(self):
        """addParsers registers "push" and "pull" subcommands (Artifactory repo)."""
        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
        subparsers = parser.add_subparsers(dest="job")
        subparsers = self.artifactory.addParsers(subparsers)

        self.assertIsNotNone(subparsers.choices["push"])
        self.assertIsNotNone(subparsers.choices["pull"])

    def test_addParsers_s3(self):
        """addParsers registers "push" and "pull" subcommands (S3 repo)."""
        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
        subparsers = parser.add_subparsers(dest="job")
        subparsers = self.s3.addParsers(subparsers)

        self.assertIsNotNone(subparsers.choices["push"])
        self.assertIsNotNone(subparsers.choices["pull"])

    @mock.patch('skelebot.components.repository.artifactoryRepo.input')
    @mock.patch('os.rename')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_push_conflict_artifactory(self, mock_artifactory, mock_rename, mock_input):
        """A non-forced Artifactory push must raise the "already exists" error.

        User and token are None, so the credentials come from the patched
        ``input`` (both "abc").
        """
        mock_input.return_value = "abc"
        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="push", force=False, artifact='test', user=None, token=None)

        with self.assertRaises(Exception) as ctx:
            self.artifactory.execute(config, args)

        self.assertEqual(str(ctx.exception), self.ALREADY_PUSHED_MSG)
        mock_artifactory.assert_called_with("artifactory.test.com/ml/test/test_v1.0.0.pkl", auth=('abc', 'abc'))

    @mock.patch('boto3.Session')
    def test_execute_push_conflict_s3(self, mock_boto3_session):
        """A non-forced S3 push must raise when the versioned key is already listed."""
        mock_client = self._mock_s3_client(mock_boto3_session, ["test_v1.0.0.pkl"])

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="push", force=False, artifact='test', user='sean', token='abc123')

        with self.assertRaises(Exception) as ctx:
            self.s3.execute(config, args)

        self.assertEqual(str(ctx.exception), self.ALREADY_PUSHED_MSG)
        mock_client.list_objects_v2.assert_called_with(Bucket="my-bucket", Prefix="test_v1.0.0.pkl")

    @mock.patch('shutil.copyfile')
    @mock.patch('os.remove')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_push_error_artifactory(self, mock_artifactory, mock_remove, mock_copy):
        """A deploy failure propagates, and the versioned copy is still removed."""
        mock_path = mock.MagicMock()
        mock_path.deploy_file = mock.MagicMock(side_effect=KeyError('foo'))
        mock_artifactory.return_value = mock_path

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="push", force=True, artifact='test', user='sean', token='abc123')

        with self.assertRaises(KeyError):
            self.artifactory.execute(config, args)

        mock_artifactory.assert_called_with("artifactory.test.com/ml/test/test_v1.0.0.pkl", auth=('sean', 'abc123'))
        mock_copy.assert_called_with("test.pkl", "test_v1.0.0.pkl")
        mock_remove.assert_called_with("test_v1.0.0.pkl")

    @mock.patch('shutil.copyfile')
    @mock.patch('os.remove')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_push_artifactory(self, mock_artifactory, mock_remove, mock_copy):
        """A forced Artifactory push copies to a versioned name, deploys, then cleans up."""
        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="push", force=True, artifact='test', user='sean', token='abc123')

        self.artifactory.execute(config, args)

        mock_artifactory.assert_called_with("artifactory.test.com/ml/test/test_v1.0.0.pkl", auth=('sean', 'abc123'))
        mock_copy.assert_called_with("test.pkl", "test_v1.0.0.pkl")
        mock_remove.assert_called_with("test_v1.0.0.pkl")

    @mock.patch('boto3.Session')
    def test_execute_push_s3(self, mock_boto3_session):
        """A forced S3 push uploads the file under its versioned key."""
        mock_client = self._mock_s3_client(mock_boto3_session)

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="push", force=True, artifact='test', user='sean', token='abc123')

        self.s3.execute(config, args)
        mock_client.upload_file.assert_called_with("test.pkl", "my-bucket", "test_v1.0.0.pkl")

    @mock.patch('shutil.copyfile')
    @mock.patch('os.remove')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_push_artifactory_all(self, mock_artifactory, mock_remove, mock_copy):
        """Pushing artifact 'ALL' to Artifactory handles every configured artifact."""
        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="push", force=True, artifact='ALL', user='sean', token='abc123')

        self.artifactory.execute(config, args)

        # Order between artifacts is not part of the contract being tested.
        mock_artifactory.assert_has_calls([
            mock.call("artifactory.test.com/ml/test/test_v1.0.0.pkl", auth=('sean', 'abc123')),
            mock.call("artifactory.test.com/ml/test/test2_v1.0.0.pkl", auth=('sean', 'abc123'))
        ], any_order=True)
        mock_copy.assert_has_calls([
            mock.call("test.pkl", "test_v1.0.0.pkl"),
            mock.call("test2.pkl", "test2_v1.0.0.pkl")
        ], any_order=True)
        mock_remove.assert_has_calls([
            mock.call("test_v1.0.0.pkl"),
            mock.call("test2_v1.0.0.pkl")
        ], any_order=True)

    @mock.patch('boto3.Session')
    def test_execute_push_s3_all(self, mock_boto3_session):
        """Pushing artifact 'ALL' to S3 uploads every configured artifact."""
        mock_client = self._mock_s3_client(mock_boto3_session)

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="push", force=True, artifact='ALL', user='sean', token='abc124')

        self.s3.execute(config, args)
        mock_client.upload_file.assert_has_calls([
            mock.call("test.pkl", "my-bucket", "test_v1.0.0.pkl"),
            mock.call("test2.pkl", "my-bucket", "test2_v1.0.0.pkl")
        ], any_order=True)

    @mock.patch('boto3.Session')
    def test_execute_push_s3_subfolder(self, mock_boto3_session):
        """A bucket of "my-bucket/sub/folder" uploads to "my-bucket" with a key prefix."""
        mock_client = self._mock_s3_client(mock_boto3_session)

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="push", force=True, artifact='test', user='sean', token='abc123')

        self.s3_subfolder.execute(config, args)
        mock_client.upload_file.assert_called_with("test.pkl", "my-bucket", "sub/folder/test_v1.0.0.pkl")

    @mock.patch('skelebot.components.repository.artifactoryRepo.input')
    @mock.patch('builtins.open')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_pull_artifactory(self, mock_artifactory, mock_open, mock_input):
        """Pulling an explicit version from Artifactory writes the versioned file name."""
        mock_input.return_value = "abc"

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="pull", version='0.1.0', artifact='test', user=None, token=None, override=False)

        self.artifactory.execute(config, args)

        mock_artifactory.assert_called_with("artifactory.test.com/ml/test/test_v0.1.0.pkl", auth=("abc", "abc"))
        mock_open.assert_called_with("test_v0.1.0.pkl", "wb")

    @mock.patch('boto3.Session')
    def test_execute_pull_s3(self, mock_boto3_session):
        """Pulling an explicit version from S3 downloads to the versioned file name."""
        mock_client = self._mock_s3_client(mock_boto3_session)

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="pull", version='0.1.0', artifact='test', user=None, token=None, override=False)

        self.s3.execute(config, args)
        mock_client.download_file.assert_called_with("my-bucket", "test_v0.1.0.pkl", "test_v0.1.0.pkl")

    @mock.patch('boto3.Session')
    def test_execute_pull_s3_subfolder(self, mock_boto3_session):
        """A sub-path bucket prefixes the S3 key but not the local file name."""
        mock_client = self._mock_s3_client(mock_boto3_session)

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="pull", version='0.1.0', artifact='test', user=None, token=None, override=False)

        self.s3_subfolder.execute(config, args)
        mock_client.download_file.assert_called_with("my-bucket", "sub/folder/test_v0.1.0.pkl", "test_v0.1.0.pkl")

    @mock.patch('skelebot.components.repository.artifactoryRepo.input')
    @mock.patch('builtins.open')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_pull_lcv_artifactory(self, mock_artifactory, mock_open, mock_input):
        """A 'LATEST' Artifactory pull at project version 1.0.9 resolves to 1.0.0.

        "lcv" presumably stands for "latest compatible version"; of the listed
        versions (1.1.0, 0.2.4, 1.0.0, 2.0.1) the test expects 1.0.0.
        """
        mock_apath = mock_artifactory.return_value
        mock_input.return_value = "abc"
        mock_apath.__iter__.return_value = ["test_v1.1.0", "test_v0.2.4", "test_v1.0.0", "test_v2.0.1"]

        config = sb.objects.config.Config(version="1.0.9")
        args = argparse.Namespace(job="pull", version='LATEST', artifact='test', user=None, token=None, override=False)

        self.artifactory.execute(config, args)

        mock_artifactory.assert_called_with("artifactory.test.com/ml/test/test_v1.0.0.pkl", auth=("abc", "abc"))
        mock_open.assert_called_with("test_v1.0.0.pkl", "wb")

    @mock.patch('boto3.Session')
    def test_execute_pull_lcv_s3(self, mock_boto3_session):
        """A 'LATEST' S3 pull at project version 1.0.9 resolves to 1.0.5.

        The listing is requested with the major-version prefix "test_v1".
        """
        mock_client = self._mock_s3_client(
            mock_boto3_session,
            ["test_v1.1.0.pkl", "test_v1.0.5.pkl", "test_v1.0.0.pkl"]
        )

        config = sb.objects.config.Config(version="1.0.9")
        args = argparse.Namespace(job="pull", version='LATEST', artifact='test', user=None, token=None, override=False)

        self.s3.execute(config, args)

        mock_client.list_objects_v2.assert_called_with(Bucket="my-bucket", Prefix="test_v1")
        mock_client.download_file.assert_called_with("my-bucket", "test_v1.0.5.pkl", "test_v1.0.5.pkl")

    @mock.patch('skelebot.components.repository.artifactoryRepo.input')
    @mock.patch('builtins.open')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_pull_lcv_not_found_artifactory(self, mock_artifactory, mock_open, mock_input):
        """A 'LATEST' Artifactory pull at version 3.0.9 finds nothing and raises RuntimeError."""
        mock_apath = mock_artifactory.return_value
        mock_input.return_value = "abc"
        mock_apath.__iter__.return_value = ["test_v1.1.0", "test_v0.2.4", "test_v1.0.0", "test_v2.0.1"]

        config = sb.objects.config.Config(version="3.0.9")
        args = argparse.Namespace(job="pull", version='LATEST', artifact='test', user=None, token=None, override=False)

        with self.assertRaises(RuntimeError) as ctx:
            self.artifactory.execute(config, args)

        self.assertEqual(str(ctx.exception), self.NO_COMPATIBLE_VERSION_MSG)

    @mock.patch('boto3.Session')
    def test_execute_pull_lcv_not_found_s3(self, mock_boto3_session):
        """A 'LATEST' S3 pull at version 2.0.9 finds nothing and raises RuntimeError."""
        self._mock_s3_client(
            mock_boto3_session,
            ["test_v1.1.0.pkl", "test_v1.0.5.pkl", "test_v1.0.0.pkl"]
        )

        config = sb.objects.config.Config(version="2.0.9")
        args = argparse.Namespace(job="pull", version='LATEST', artifact='test', user=None, token=None, override=False)

        with self.assertRaises(RuntimeError) as ctx:
            self.s3.execute(config, args)

        self.assertEqual(str(ctx.exception), self.NO_COMPATIBLE_VERSION_MSG)

    @mock.patch('skelebot.components.repository.artifactoryRepo.input')
    @mock.patch('builtins.open')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_pull_override_and_lcv_artifactory(self, mock_artifactory, mock_open, mock_input):
        """With override=True the pulled Artifactory file is written to the unversioned name.

        At project version 0.6.9 the test expects version 0.2.4 to be fetched
        and written to "test.pkl".
        """
        mock_apath = mock_artifactory.return_value
        mock_input.return_value = "abc"
        mock_apath.__iter__.return_value = ["test_v1.1.0", "test_v0.2.4", "test_v1.0.0", "test_v2.0.1"]

        config = sb.objects.config.Config(version="0.6.9")
        args = argparse.Namespace(job="pull", version='LATEST', artifact='test', user=None, token=None, override=True)

        self.artifactory.execute(config, args)

        mock_artifactory.assert_called_with("artifactory.test.com/ml/test/test_v0.2.4.pkl", auth=("abc", "abc"))
        mock_open.assert_called_with("test.pkl", "wb")

    @mock.patch('boto3.Session')
    def test_execute_pull_override_and_lcv_s3(self, mock_boto3_session):
        """With override=True the pulled S3 object is downloaded to the unversioned name.

        At project version 1.0.3 the test expects version 1.0.0 to be fetched
        and saved as "test.pkl".
        """
        mock_client = self._mock_s3_client(
            mock_boto3_session,
            ["test_v1.1.0.pkl", "test_v1.0.5.pkl", "test_v1.0.0.pkl"]
        )

        config = sb.objects.config.Config(version="1.0.3")
        args = argparse.Namespace(job="pull", version='LATEST', artifact='test', user=None, token=None, override=True)

        self.s3.execute(config, args)

        mock_client.download_file.assert_called_with("my-bucket", "test_v1.0.0.pkl", "test.pkl")

    @mock.patch('skelebot.components.repository.artifactoryRepo.input')
    @mock.patch('artifactory.ArtifactoryPath')
    def test_execute_pull_not_found(self, mock_artifactory, mock_input):
        """Pulling a path whose ``exists()`` is False completes without raising."""
        mock_input.return_value = "abc"
        path = mock_artifactory.return_value
        path.exists.return_value = False

        config = sb.objects.config.Config(version="1.0.0")
        args = argparse.Namespace(job="pull", version='0.1.0', artifact='test', user=None, token=None, override=False)

        self.artifactory.execute(config, args)

        mock_artifactory.assert_called_with("artifactory.test.com/ml/test/test_v0.1.0.pkl", auth=("abc", "abc"))

    def test_validate_valid(self):
        """The baseline Artifactory, S3 and Artifact dicts all pass validation."""
        cases = (
            (sb.components.repository.artifactoryRepo.ArtifactoryRepo.validate, self.artifactoryDict),
            (sb.components.repository.s3Repo.S3Repo.validate, self.s3Dict),
            (sb.components.repository.repository.Artifact.validate, self.artifactDict),
        )

        for validate, config in cases:
            try:
                validate(config)
            except Exception as exc:
                # Not a bare except: KeyboardInterrupt/SystemExit must not be masked.
                self.fail("Validation Raised Exception Unexpectedly: {}".format(exc))

    def test_validate_missing(self):
        """Validation reports every missing required key.

        For S3 only 'bucket' and 'region' are expected in the message even
        though 'profile' is removed as well.
        """
        s3Dict = copy.deepcopy(self.s3Dict)
        for key in ('bucket', 'region', 'profile'):
            del s3Dict[key]

        with self.assertRaises(SchemaError) as ctx:
            sb.components.repository.s3Repo.S3Repo.validate(s3Dict)
        self.assertEqual(ctx.exception.code, "Missing keys: 'bucket', 'region'")

        artifactoryDict = copy.deepcopy(self.artifactoryDict)
        for key in ('url', 'repo', 'path'):
            del artifactoryDict[key]

        with self.assertRaises(SchemaError) as ctx:
            sb.components.repository.artifactoryRepo.ArtifactoryRepo.validate(artifactoryDict)
        self.assertEqual(ctx.exception.code, "Missing keys: 'path', 'repo', 'url'")

        artifactDict = copy.deepcopy(self.artifactDict)
        for key in ('name', 'file'):
            del artifactDict[key]

        with self.assertRaises(SchemaError) as ctx:
            sb.components.repository.repository.Artifact.validate(artifactDict)
        self.assertEqual(ctx.exception.code, "Missing keys: 'file', 'name'")

    def validate_error_s3(self, attr, reset, expected):
        """Assert S3 validation rejects ``attr`` set to ``reset`` with the type message."""
        self._assert_type_error(
            sb.components.repository.s3Repo.S3Repo.validate,
            self.s3Dict, "S3", attr, reset, expected
        )

    def validate_error_artifactory(self, attr, reset, expected):
        """Assert Artifactory validation rejects ``attr`` set to ``reset`` with the type message."""
        self._assert_type_error(
            sb.components.repository.artifactoryRepo.ArtifactoryRepo.validate,
            self.artifactoryDict, "Artifactory", attr, reset, expected
        )

    def validate_error_artifact(self, attr, reset, expected):
        """Assert Artifact validation rejects ``attr`` set to ``reset`` with the type message."""
        self._assert_type_error(
            sb.components.repository.repository.Artifact.validate,
            self.artifactDict, "Artifact", attr, reset, expected
        )

    def test_invalid(self):
        """Every string-typed config field rejects an integer value."""
        self.validate_error_s3('bucket', 123, 'String')
        self.validate_error_s3('region', 123, 'String')
        self.validate_error_s3('profile', 123, 'String')
        self.validate_error_artifactory('url', 123, 'String')
        self.validate_error_artifactory('repo', 123, 'String')
        self.validate_error_artifactory('path', 123, 'String')
        self.validate_error_artifact('name', 123, 'String')
        self.validate_error_artifact('file', 123, 'String')

# Run the unittest runner when this module is executed directly as a script.
if "__main__" == __name__:
    main()