import mock import os import pytest from six.moves import reload_module as reload import mlflow from mlflow.store.db.db_types import DATABASE_ENGINES from mlflow.store.tracking.file_store import FileStore from mlflow.store.tracking.rest_store import RestStore from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException from mlflow.tracking._tracking_service.registry import TrackingStoreRegistry from mlflow.tracking._tracking_service.utils import _get_store, _TRACKING_URI_ENV_VAR, \ _TRACKING_USERNAME_ENV_VAR, _TRACKING_PASSWORD_ENV_VAR, _TRACKING_TOKEN_ENV_VAR, \ _TRACKING_INSECURE_TLS_ENV_VAR # pylint: disable=unused-argument # Disable mocking tracking URI here, as we want to test setting the tracking URI via # environment variable. See # http://doc.pytest.org/en/latest/skipping.html#skip-all-test-functions-of-a-class-or-module # and https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.rst#writing-python-tests # for more information. 
pytestmark = pytest.mark.notrackingurimock

# Tracking URI shared by the REST-store tests below.
_REST_TRACKING_URI = "https://my-tracking-server:5050"


def _patched_environ(overrides):
    """Return a context manager that overlays ``overrides`` onto ``os.environ``."""
    return mock.patch.dict(os.environ, overrides)


def test_get_store_file_store(tmp_wkdir):
    """With an empty environment overlay, ``_get_store()`` gives a FileStore at ``mlruns``."""
    with _patched_environ({}):
        resolved = _get_store()
        assert isinstance(resolved, FileStore)
        expected_root = os.path.abspath("mlruns")
        assert os.path.abspath(resolved.root_directory) == expected_root


def test_get_store_file_store_from_arg(tmp_wkdir):
    """A path passed directly to ``_get_store`` becomes the FileStore root directory."""
    with _patched_environ({}):
        resolved = _get_store("other/path")
        assert isinstance(resolved, FileStore)
        expected_root = os.path.abspath("other/path")
        assert os.path.abspath(resolved.root_directory) == expected_root


@pytest.mark.parametrize("uri", ["other/path", "file:other/path"])
def test_get_store_file_store_from_env(tmp_wkdir, uri):
    """A plain path or ``file:`` URI in the tracking env var yields a FileStore there."""
    with _patched_environ({_TRACKING_URI_ENV_VAR: uri}):
        resolved = _get_store()
        assert isinstance(resolved, FileStore)
        expected_root = os.path.abspath("other/path")
        assert os.path.abspath(resolved.root_directory) == expected_root


def test_get_store_basic_rest_store():
    """An https tracking URI yields a RestStore whose credentials carry no token."""
    with _patched_environ({_TRACKING_URI_ENV_VAR: _REST_TRACKING_URI}):
        resolved = _get_store()
        assert isinstance(resolved, RestStore)
        assert resolved.get_host_creds().host == "https://my-tracking-server:5050"
        assert resolved.get_host_creds().token is None


def test_get_store_rest_store_with_password():
    """Username/password env vars are surfaced through the RestStore host credentials."""
    overrides = {
        _TRACKING_URI_ENV_VAR: _REST_TRACKING_URI,
        _TRACKING_USERNAME_ENV_VAR: "Bob",
        _TRACKING_PASSWORD_ENV_VAR: "Ross",
    }
    with _patched_environ(overrides):
        resolved = _get_store()
        assert isinstance(resolved, RestStore)
        assert resolved.get_host_creds().host == "https://my-tracking-server:5050"
        assert resolved.get_host_creds().username == "Bob"
        assert resolved.get_host_creds().password == "Ross"


def test_get_store_rest_store_with_token():
    """The token env var is surfaced through the RestStore host credentials."""
    overrides = {
        _TRACKING_URI_ENV_VAR: _REST_TRACKING_URI,
        _TRACKING_TOKEN_ENV_VAR: "my-token",
    }
    with _patched_environ(overrides):
        resolved = _get_store()
        assert isinstance(resolved, RestStore)
        assert resolved.get_host_creds().token == "my-token"


def test_get_store_rest_store_with_insecure():
    """Setting the insecure-TLS env var to "true" turns on ``ignore_tls_verification``."""
    overrides = {
        _TRACKING_URI_ENV_VAR: _REST_TRACKING_URI,
        _TRACKING_INSECURE_TLS_ENV_VAR: "true",
    }
    with _patched_environ(overrides):
        resolved = _get_store()
        assert isinstance(resolved, RestStore)
        assert resolved.get_host_creds().ignore_tls_verification


def test_get_store_rest_store_with_no_insecure():
    """TLS verification is kept when the insecure flag is "false" or not supplied."""
    explicit_false = {
        _TRACKING_URI_ENV_VAR: _REST_TRACKING_URI,
        _TRACKING_INSECURE_TLS_ENV_VAR: "false",
    }
    with _patched_environ(explicit_false):
        resolved = _get_store()
        assert isinstance(resolved, RestStore)
        assert not resolved.get_host_creds().ignore_tls_verification

    # By default, should not ignore verification.
    uri_only = {
        _TRACKING_URI_ENV_VAR: _REST_TRACKING_URI,
    }
    with _patched_environ(uri_only):
        resolved = _get_store()
        assert isinstance(resolved, RestStore)
        assert not resolved.get_host_creds().ignore_tls_verification


@pytest.mark.parametrize("db_type", DATABASE_ENGINES)
def test_get_store_sqlalchemy_store(tmp_wkdir, db_type):
    """A database URI yields a SqlAlchemyStore with the ``./mlruns`` artifact root."""
    db_uri = "{}://hostname/database".format(db_type)
    engine_patch = mock.patch("sqlalchemy.create_engine")

    # Engine creation, schema verification and table initialization are all
    # patched out for the duration of the store construction.
    with _patched_environ({_TRACKING_URI_ENV_VAR: db_uri}):
        with engine_patch as create_engine_mock:
            with mock.patch("mlflow.store.db.utils._verify_schema"):
                with mock.patch("mlflow.store.db.utils._initialize_tables"):
                    resolved = _get_store()
                    assert isinstance(resolved, SqlAlchemyStore)
                    assert resolved.db_uri == db_uri
                    assert resolved.artifact_root_uri == "./mlruns"

    create_engine_mock.assert_called_once_with(db_uri, pool_pre_ping=True)


@pytest.mark.parametrize("db_type", DATABASE_ENGINES)
def test_get_store_sqlalchemy_store_with_artifact_uri(tmp_wkdir, db_type):
    """An explicit ``artifact_uri`` argument is kept as the SqlAlchemyStore artifact root."""
    db_uri = "{}://hostname/database".format(db_type)
    artifact_uri = "file:artifact/path"
    engine_patch = mock.patch("sqlalchemy.create_engine")

    with _patched_environ({_TRACKING_URI_ENV_VAR: db_uri}):
        with engine_patch as create_engine_mock:
            with mock.patch("mlflow.store.db.utils._verify_schema"):
                with mock.patch("mlflow.store.db.utils._initialize_tables"):
                    resolved = _get_store(artifact_uri=artifact_uri)
                    assert isinstance(resolved, SqlAlchemyStore)
                    assert resolved.db_uri == db_uri
                    assert resolved.artifact_root_uri == artifact_uri

    create_engine_mock.assert_called_once_with(db_uri, pool_pre_ping=True)


def test_get_store_databricks():
    """The ``databricks`` URI yields a RestStore using the DATABRICKS_* env credentials."""
    overrides = {
        _TRACKING_URI_ENV_VAR: "databricks",
        'DATABRICKS_HOST': "https://my-tracking-server",
        'DATABRICKS_TOKEN': "abcdef",
    }
    with _patched_environ(overrides):
        resolved = _get_store()
        assert isinstance(resolved, RestStore)
        assert resolved.get_host_creds().host == "https://my-tracking-server"
        assert resolved.get_host_creds().token == "abcdef"


def test_get_store_databricks_profile():
    """A ``databricks://<profile>`` URI yields a RestStore that refers to that profile."""
    overrides = {
        _TRACKING_URI_ENV_VAR: "databricks://mycoolprofile",
    }
    # It's kind of annoying to set up a profile, and we're not really trying to test
    # that anyway, so just check that a relevant exception is raised.
    with _patched_environ(overrides):
        resolved = _get_store()
        assert isinstance(resolved, RestStore)
        with pytest.raises(Exception) as raised:
            resolved.get_host_creds()
        assert 'mycoolprofile' in str(raised.value)


def test_standard_store_registry_with_mocked_entrypoint():
    """Reloading the utils module registers the standard schemes plus a mocked entrypoint."""
    fake_entrypoint = mock.Mock()
    fake_entrypoint.name = "mock-scheme"

    with mock.patch("entrypoints.get_group_all", return_value=[fake_entrypoint]):
        # Entrypoints are registered at import time, so we need to reload the
        # module to register the entrypoint given by the mocked
        # entrypoints.get_group_all
        reload(mlflow.tracking._tracking_service.utils)

        required_schemes = {
            '', 'file', 'http', 'https', 'postgresql', 'mysql', 'sqlite', 'mssql',
            'databricks', 'mock-scheme'
        }
        assert required_schemes.issubset(
            mlflow.tracking._tracking_service.utils._tracking_store_registry._registry.keys()
        )


@pytest.mark.large
def test_standard_store_registry_with_installed_plugin(tmp_wkdir):
    """This test requires the package in tests/resources/mlflow-test-plugin to be installed"""

    reload(mlflow.tracking._tracking_service.utils)
    assert "file-plugin" in \
        mlflow.tracking._tracking_service.utils._tracking_store_registry._registry.keys()

    from mlflow_test_plugin.file_store import PluginFileStore

    overrides = {
        _TRACKING_URI_ENV_VAR: "file-plugin:test-path",
    }
    with _patched_environ(overrides):
        plugin_store = mlflow.tracking._tracking_service.utils._get_store()
        assert isinstance(plugin_store, PluginFileStore)
        assert plugin_store.is_plugin


def test_plugin_registration():
    """A builder registered for a scheme is invoked for URIs that use that scheme."""
    scheme = "mock-scheme"
    store_uri = "mock-scheme://fake-host/fake-path"
    builder = mock.Mock()

    registry = TrackingStoreRegistry()
    registry.register(scheme, builder)

    assert scheme in registry._registry
    assert registry.get_store(store_uri) == builder.return_value
    builder.assert_called_once_with(store_uri=store_uri, artifact_uri=None)


def test_plugin_registration_via_entrypoints():
    """``register_entrypoints`` registers the builder loaded from a discovered entrypoint."""
    builder = mock.Mock()
    fake_entrypoint = mock.Mock(load=mock.Mock(return_value=builder))
    fake_entrypoint.name = "mock-scheme"

    group_patch = mock.patch("entrypoints.get_group_all", return_value=[fake_entrypoint])
    with group_patch as get_group_all_mock:
        registry = TrackingStoreRegistry()
        registry.register_entrypoints()

    assert registry.get_store("mock-scheme://") == builder.return_value
    builder.assert_called_once_with(store_uri="mock-scheme://", artifact_uri=None)
    get_group_all_mock.assert_called_once_with("mlflow.tracking_store")


@pytest.mark.parametrize("exception",
                         [AttributeError("test exception"),
                          ImportError("test exception")])
def test_handle_plugin_registration_failure_via_entrypoints(exception):
    """An entrypoint whose ``load`` raises produces a UserWarning carrying the message."""
    fake_entrypoint = mock.Mock(load=mock.Mock(side_effect=exception))
    fake_entrypoint.name = "mock-scheme"

    group_patch = mock.patch("entrypoints.get_group_all", return_value=[fake_entrypoint])
    with group_patch as get_group_all_mock:
        registry = TrackingStoreRegistry()

        # Check that the raised warning contains the message from the original exception
        with pytest.warns(UserWarning, match="test exception"):
            registry.register_entrypoints()

    fake_entrypoint.load.assert_called_once()
    get_group_all_mock.assert_called_once_with("mlflow.tracking_store")


def test_get_store_for_unregistered_scheme():
    """Requesting a store for a scheme that was never registered raises."""
    registry = TrackingStoreRegistry()

    with pytest.raises(UnsupportedModelRegistryStoreURIException):
        registry.get_store("unknown-scheme://")