"""Tests for ``mlflow.projects._project_spec.EntryPoint``.

Covers merging user-supplied parameters with declared defaults (and separating
out undeclared "extra" parameters), shell-quoting of the generated command, and
resolution of parameters declared with type ``path`` (local files accepted,
remote URIs downloaded) and type ``uri`` (not downloaded; bare local paths rejected).
"""
import os

import mock
import pytest
from six.moves import shlex_quote

from mlflow.exceptions import ExecutionException
from mlflow.projects._project_spec import EntryPoint
from mlflow.utils.file_utils import TempDir, path_to_local_file_uri
from tests.projects.utils import load_project, TEST_PROJECT_DIR

# Patch target for the artifact-download helper; the tests below assert on how often it is
# called while resolving parameters of type `path` / `uri`.
_DOWNLOAD_TARGET = "mlflow.tracking.artifact_utils._download_artifact_from_uri"


def test_entry_point_compute_params():
    """
    Tests that EntryPoint correctly computes a final set of parameters to use when running a project
    """
    project = load_project()
    entry_point = project.get_entry_point("greeter")
    with TempDir() as tmp:
        # Pass the directory path (not the TempDir object itself) as storage_dir, as every
        # other test in this module does.
        storage_dir = tmp.path()

        # Pass extra "excitement" param, use default value for `greeting` param
        params, extra_params = entry_point.compute_parameters(
            {"name": "friend", "excitement": 10}, storage_dir)
        assert params == {"name": "friend", "greeting": "hi"}
        assert extra_params == {"excitement": "10"}

        # Don't pass extra "excitement" param, pass value for `greeting`
        params, extra_params = entry_point.compute_parameters(
            {"name": "friend", "greeting": "hello"}, storage_dir)
        assert params == {"name": "friend", "greeting": "hello"}
        assert extra_params == {}

        # Raise exception on missing required parameter
        with pytest.raises(ExecutionException):
            entry_point.compute_parameters({}, storage_dir)


def test_entry_point_compute_command():
    """
    Tests that EntryPoint correctly computes the command to execute in order to run the entry point.
    """
    project = load_project()
    entry_point = project.get_entry_point("greeter")
    with TempDir() as tmp:
        storage_dir = tmp.path()
        command = entry_point.compute_command({"name": "friend", "excitement": 10}, storage_dir)
        assert command == "python greeter.py hi friend --excitement 10"

        # Missing required parameter
        with pytest.raises(ExecutionException):
            entry_point.compute_command({}, storage_dir)

        # Test shell escaping: values containing shell metacharacters must come back quoted
        name_value = "friend; echo 'hi'"
        command = entry_point.compute_command({"name": name_value}, storage_dir)
        assert command == "python greeter.py %s %s" % (shlex_quote("hi"), shlex_quote(name_value))


def test_path_parameter():
    """
    Tests that MLflow file-download APIs get called when necessary for arguments of type `path`.
    """
    project = load_project()
    entry_point = project.get_entry_point("line_count")
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock:
        download_uri_mock.return_value = 0
        # Verify that we don't attempt to call download_uri when passing a local file to a
        # parameter of type "path"
        with TempDir() as tmp:
            dst_dir = tmp.path()
            local_path = os.path.join(TEST_PROJECT_DIR, "MLproject")
            params, _ = entry_point.compute_parameters(
                user_parameters={"path": local_path}, storage_dir=dst_dir)
            assert params["path"] == os.path.abspath(local_path)
            assert download_uri_mock.call_count == 0

            # The same local file expressed as a file:// URI is also resolved without download
            params, _ = entry_point.compute_parameters(
                user_parameters={"path": path_to_local_file_uri(local_path)},
                storage_dir=dst_dir)
            assert params["path"] == os.path.abspath(local_path)
            assert download_uri_mock.call_count == 0

        # Verify that we raise an exception when passing a non-existent local file to a
        # parameter of type "path". Only the compute_parameters call sits inside
        # pytest.raises, so an error while setting up the temp dir cannot satisfy it.
        with TempDir() as tmp:
            dst_dir = tmp.path()
            with pytest.raises(ExecutionException):
                entry_point.compute_parameters(
                    user_parameters={"path": os.path.join(dst_dir, "some/nonexistent/file")},
                    storage_dir=dst_dir)

        # Verify that we do call `download_uri` when passing a URI to a parameter of type "path"
        for i, prefix in enumerate(["dbfs:/", "s3://", "gs://"]):
            with TempDir() as tmp:
                dst_dir = tmp.path()
                file_to_download = "images.tgz"
                download_path = "%s/%s" % (dst_dir, file_to_download)
                download_uri_mock.return_value = download_path
                # A remote URI is not a filesystem path, so build it by concatenation rather
                # than os.path.join (which would depend on the OS path separator).
                params, _ = entry_point.compute_parameters(
                    user_parameters={"path": prefix + file_to_download},
                    storage_dir=dst_dir)
                assert params["path"] == download_path
                # Exactly one additional download per remote URI
                assert download_uri_mock.call_count == i + 1


def test_uri_parameter():
    """Tests parameter resolution for parameters of type `uri`."""
    project = load_project()
    entry_point = project.get_entry_point("download_uri")
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock, TempDir() as tmp:
        dst_dir = tmp.path()
        # Test that we don't attempt to locally download parameters of type URI.
        # Use path_to_local_file_uri instead of hand-formatting "file://%s", which is only
        # well-formed for POSIX absolute paths.
        entry_point.compute_command(
            user_parameters={"uri": path_to_local_file_uri(dst_dir)}, storage_dir=dst_dir)
        assert download_uri_mock.call_count == 0
        # Test that we raise an exception if a local path is passed to a parameter of type URI
        with pytest.raises(ExecutionException):
            entry_point.compute_command(user_parameters={"uri": dst_dir}, storage_dir=dst_dir)


def test_params():
    """
    Tests validation of user parameters against declared ones and the split of user input into
    final (declared, defaults filled in, values stringified) and extra (undeclared) parameters.
    """
    defaults = {
        "alpha": "float",
        "l1_ratio": {"type": "float", "default": 0.1},
        "l2_ratio": {"type": "float", "default": 0.0003},
        "random_str": {"type": "string", "default": "hello"},
    }
    entry_point = EntryPoint("entry_point_name", defaults, "command_name script.py")

    # Required "alpha" missing entirely
    user_1 = {}
    with pytest.raises(ExecutionException):
        entry_point._validate_parameters(user_1)

    # Only an undeclared parameter supplied; required "alpha" still missing
    user_2 = {"beta": 0.004}
    with pytest.raises(ExecutionException):
        entry_point._validate_parameters(user_2)

    user_3 = {"alpha": 0.004, "gamma": 0.89}
    expected_final_3 = {"alpha": "0.004", "l1_ratio": "0.1", "l2_ratio": "0.0003",
                        "random_str": "hello"}
    expected_extra_3 = {"gamma": "0.89"}
    final_3, extra_3 = entry_point.compute_parameters(user_3, None)
    assert expected_extra_3 == extra_3
    assert expected_final_3 == final_3

    user_4 = {"alpha": 0.004, "l1_ratio": 0.0008, "random_str_2": "hello"}
    expected_final_4 = {"alpha": "0.004", "l1_ratio": "0.0008", "l2_ratio": "0.0003",
                        "random_str": "hello"}
    expected_extra_4 = {"random_str_2": "hello"}
    final_4, extra_4 = entry_point.compute_parameters(user_4, None)
    assert expected_extra_4 == extra_4
    assert expected_final_4 == final_4

    user_5 = {"alpha": -0.99, "random_str": "hi"}
    expected_final_5 = {"alpha": "-0.99", "l1_ratio": "0.1", "l2_ratio": "0.0003",
                        "random_str": "hi"}
    expected_extra_5 = {}
    final_5, extra_5 = entry_point.compute_parameters(user_5, None)
    assert expected_final_5 == final_5
    assert expected_extra_5 == extra_5

    # Parameter names are case-sensitive: "ALPHA" is treated as an extra parameter
    user_6 = {"alpha": 0.77, "ALPHA": 0.89}
    expected_final_6 = {"alpha": "0.77", "l1_ratio": "0.1", "l2_ratio": "0.0003",
                        "random_str": "hello"}
    expected_extra_6 = {"ALPHA": "0.89"}
    final_6, extra_6 = entry_point.compute_parameters(user_6, None)
    assert expected_extra_6 == extra_6
    assert expected_final_6 == final_6


def test_path_params():
    """
    Tests that `uri` parameters are never downloaded, and that `path` parameters pointing at
    remote URIs are downloaded only when a storage_dir is supplied.
    """
    data_file = "s3://path.test/resources/data_file.csv"
    defaults = {
        "constants": {"type": "uri", "default": "s3://path.test/b1"},
        "data": {"type": "path", "default": data_file},
    }
    entry_point = EntryPoint("entry_point_name", defaults, "command_name script.py")

    # No storage_dir: defaults are returned unchanged and nothing is downloaded
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock:
        final_1, extra_1 = entry_point.compute_parameters({}, None)
        assert final_1 == {"constants": "s3://path.test/b1", "data": data_file}
        assert extra_1 == {}
        assert download_uri_mock.call_count == 0

    # No storage_dir, user overrides the `uri` param: still nothing downloaded
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock:
        user_2 = {"alpha": 0.001, "constants": "s3://path.test/b_two"}
        final_2, extra_2 = entry_point.compute_parameters(user_2, None)
        assert final_2 == {"constants": "s3://path.test/b_two", "data": data_file}
        assert extra_2 == {"alpha": "0.001"}
        assert download_uri_mock.call_count == 0

    # With storage_dir: the default `path` URI is downloaded once
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock, TempDir() as tmp:
        dest_path = tmp.path()
        download_path = "%s/data_file.csv" % dest_path
        download_uri_mock.return_value = download_path
        user_3 = {"alpha": 0.001}
        final_3, extra_3 = entry_point.compute_parameters(user_3, dest_path)
        assert final_3 == {"constants": "s3://path.test/b1", "data": download_path}
        assert extra_3 == {"alpha": "0.001"}
        assert download_uri_mock.call_count == 1

    # With storage_dir: a user-supplied `path` URI is downloaded once
    with mock.patch(_DOWNLOAD_TARGET) as download_uri_mock, TempDir() as tmp:
        dest_path = tmp.path()
        download_path = "%s/images.tgz" % dest_path
        download_uri_mock.return_value = download_path
        user_4 = {"data": "s3://another.example.test/data_stash/images.tgz"}
        final_4, extra_4 = entry_point.compute_parameters(user_4, dest_path)
        assert final_4 == {"constants": "s3://path.test/b1", "data": download_path}
        assert extra_4 == {}
        assert download_uri_mock.call_count == 1