# Copyright 2019 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test cases for the firebase_admin.ml module.""" import json import pytest import firebase_admin from firebase_admin import exceptions from firebase_admin import ml from tests import testutils BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' HEADER_CLIENT_VALUE = 'fire-admin-python/{0}'.format(firebase_admin.__version__) PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' CREATE_TIME = '2020-01-21T20:44:27.392932Z' CREATE_TIME_MILLIS = 1579639467392 UPDATE_TIME = '2020-01-21T22:45:29.392932Z' UPDATE_TIME_MILLIS = 1579646729392 CREATE_TIME_2 = '2020-01-21T21:44:27.392932Z' UPDATE_TIME_2 = '2020-01-21T23:45:29.392932Z' ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' MODEL_HASH = '987987a98b98798d098098e09809fc0893897' TAG_1 = 'Tag1' TAG_2 = 'Tag2' TAG_3 = 'Tag3' TAGS = [TAG_1, TAG_2] TAGS_2 = [TAG_1, TAG_3] MODEL_ID_1 = 'modelId1' MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) DISPLAY_NAME_1 = 'displayName1' MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1 } MODEL_1 = ml.Model.from_dict(MODEL_JSON_1) MODEL_ID_2 = 'modelId2' MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) DISPLAY_NAME_2 = 'displayName2' MODEL_JSON_2 = { 'name': MODEL_NAME_2, 'displayName': DISPLAY_NAME_2 } MODEL_2 = ml.Model.from_dict(MODEL_JSON_2) MODEL_ID_3 = 'modelId3' MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) DISPLAY_NAME_3 = 'displayName3' MODEL_JSON_3 = { 'name': MODEL_NAME_3, 'displayName': DISPLAY_NAME_3 } MODEL_3 = ml.Model.from_dict(MODEL_JSON_3) MODEL_STATE_PUBLISHED_JSON = { 'published': True } VALIDATION_ERROR_CODE = 400 VALIDATION_ERROR_MSG = 'No model format found for {0}.'.format(MODEL_ID_1) MODEL_STATE_ERROR_JSON = { 'validationError': { 'code': VALIDATION_ERROR_CODE, 'message': VALIDATION_ERROR_MSG, } } OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID) OPERATION_NOT_DONE_JSON_1 = { 'name': OPERATION_NAME_1, 'metadata': { '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' } } GCS_BUCKET_NAME = 'my_bucket' GCS_BLOB_NAME = 'mymodel.tflite' GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} GCS_TFLITE_MODEL_SOURCE = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) TFLITE_FORMAT_JSON = { 'gcsTfliteUri': GCS_TFLITE_URI, 'sizeBytes': '1234567' } TFLITE_FORMAT = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) GCS_TFLITE_SIGNED_URI_PATTERN = ( 'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo') GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} GCS_TFLITE_MODEL_SOURCE_2 = ml.TFLiteGCSModelSource(GCS_TFLITE_URI_2) TFLITE_FORMAT_JSON_2 = { 'gcsTfliteUri': GCS_TFLITE_URI_2, 'sizeBytes': '2345678' } TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, 'createTime': CREATE_TIME, 'updateTime': UPDATE_TIME, 'state': MODEL_STATE_ERROR_JSON, 'etag': ETAG, 'modelHash': MODEL_HASH, 'tags': TAGS, } CREATED_UPDATED_MODEL_1 = ml.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1) LOCKED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, 'createTime': CREATE_TIME, 'updateTime': UPDATE_TIME, 'tags': TAGS, 'activeOperations': [OPERATION_NOT_DONE_JSON_1] } LOCKED_MODEL_JSON_2 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_2, 'createTime': CREATE_TIME_2, 'updateTime': UPDATE_TIME_2, 'tags': TAGS_2, 'activeOperations': [OPERATION_NOT_DONE_JSON_1] } OPERATION_DONE_MODEL_JSON_1 = { 'done': True, 'response': CREATED_UPDATED_MODEL_JSON_1 } OPERATION_MALFORMED_JSON_1 = { 'done': True, # if done is true then either response or error should be populated } OPERATION_MISSING_NAME = { # Name is required if the operation is not done. 'done': False } OPERATION_ERROR_CODE = 3 OPERATION_ERROR_MSG = "Invalid argument" OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' OPERATION_ERROR_JSON_1 = { 'done': True, 'error': { 'code': OPERATION_ERROR_CODE, 'message': OPERATION_ERROR_MSG, } } FULL_MODEL_ERR_STATE_LRO_JSON = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, 'createTime': CREATE_TIME, 'updateTime': UPDATE_TIME, 'state': MODEL_STATE_ERROR_JSON, 'etag': ETAG, 'modelHash': MODEL_HASH, 'tags': TAGS, 'activeOperations': [OPERATION_NOT_DONE_JSON_1], } FULL_MODEL_PUBLISHED_JSON = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, 'createTime': CREATE_TIME, 'updateTime': UPDATE_TIME, 'state': MODEL_STATE_PUBLISHED_JSON, 'etag': ETAG, 'modelHash': MODEL_HASH, 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON } FULL_MODEL_PUBLISHED = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON = { 'name': OPERATION_NAME_1, 'done': True, 'response': FULL_MODEL_PUBLISHED_JSON } EMPTY_RESPONSE = json.dumps({}) OPERATION_NOT_DONE_RESPONSE = json.dumps(OPERATION_NOT_DONE_JSON_1) OPERATION_DONE_RESPONSE = json.dumps(OPERATION_DONE_MODEL_JSON_1) OPERATION_DONE_PUBLISHED_RESPONSE = json.dumps(OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON) OPERATION_ERROR_RESPONSE = json.dumps(OPERATION_ERROR_JSON_1) OPERATION_MALFORMED_RESPONSE = json.dumps(OPERATION_MALFORMED_JSON_1) OPERATION_MISSING_NAME_RESPONSE = json.dumps(OPERATION_MISSING_NAME) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) LOCKED_MODEL_2_RESPONSE = json.dumps(LOCKED_MODEL_JSON_2) NO_MODELS_LIST_RESPONSE = json.dumps({}) DEFAULT_LIST_RESPONSE = json.dumps({ 'models': [MODEL_JSON_1, MODEL_JSON_2], 'nextPageToken': NEXT_PAGE_TOKEN }) LAST_PAGE_LIST_RESPONSE = json.dumps({ 'models': [MODEL_JSON_3] }) ONE_PAGE_LIST_RESPONSE = json.dumps({ 'models': [MODEL_JSON_1, MODEL_JSON_2, MODEL_JSON_3], }) ERROR_CODE_NOT_FOUND = 404 ERROR_MSG_NOT_FOUND = 'The resource was not found' ERROR_STATUS_NOT_FOUND = 'NOT_FOUND' ERROR_JSON_NOT_FOUND = { 'error': { 'code': ERROR_CODE_NOT_FOUND, 'message': ERROR_MSG_NOT_FOUND, 'status': ERROR_STATUS_NOT_FOUND } } ERROR_RESPONSE_NOT_FOUND = json.dumps(ERROR_JSON_NOT_FOUND) ERROR_CODE_BAD_REQUEST = 400 ERROR_MSG_BAD_REQUEST = 'Invalid Argument' ERROR_STATUS_BAD_REQUEST = 'INVALID_ARGUMENT' ERROR_JSON_BAD_REQUEST = { 'error': { 'code': ERROR_CODE_BAD_REQUEST, 'message': ERROR_MSG_BAD_REQUEST, 'status': ERROR_STATUS_BAD_REQUEST } } ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) INVALID_MODEL_ID_ARGS = [ ('', ValueError), ('&_*#@:/?', ValueError), (None, TypeError), (12345, TypeError), ] INVALID_MODEL_ARGS = [ 'abc', 4.2, list(), dict(), True, -1, 0, None ] INVALID_OP_NAME_ARGS = [ 'abc', '123', 'operations/project/1234/model/abc/operation/123', 'projects/operations/123', 'projects/$#@/operations/123', 'projects/1234/operations/123/extrathing', ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(ml._MAX_PAGE_SIZE) INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] # For validation type errors def check_error(excinfo, err_type, msg=None): err = excinfo.value assert isinstance(err, err_type) if msg: assert str(err) == msg # For errors that are returned in an operation def check_operation_error(excinfo, code, msg): err = excinfo.value assert isinstance(err, exceptions.FirebaseError) assert err.code == code assert str(err) == msg # For rpc errors def check_firebase_error(excinfo, code, status, msg): err = excinfo.value assert isinstance(err, exceptions.FirebaseError) assert err.code == code assert err.http_response is not None assert err.http_response.status_code == status assert str(err) == msg def instrument_ml_service(status=200, payload=None, operations=False, app=None): if not app: app = firebase_admin.get_app() ml_service = ml._get_ml_service(app) recorder = [] session_url = 'https://firebaseml.googleapis.com/v1beta2/' if isinstance(status, list): adapter = testutils.MockMultiRequestAdapter else: adapter = testutils.MockAdapter if operations: ml_service._operation_client.session.mount( session_url, adapter(payload, status, recorder)) else: ml_service._client.session.mount( session_url, adapter(payload, status, recorder)) return recorder class _TestStorageClient: @staticmethod def upload(bucket_name, model_file_name, app): del app # unused variable blob_name = ml._CloudStorageClient.BLOB_NAME.format(model_file_name) return ml._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) @staticmethod def sign_uri(gcs_tflite_uri, app): del app # unused variable bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) class TestModel: """Tests ml.Model class.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test ml.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() @classmethod def teardown_class(cls): testutils.cleanup_apps() @staticmethod def _op_url(project_id): return BASE_URL + \ 'projects/{0}/operations/123'.format(project_id) def test_model_success_err_state_lro(self): model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 assert model.create_time == CREATE_TIME_MILLIS assert model.update_time == UPDATE_TIME_MILLIS assert model.validation_error == VALIDATION_ERROR_MSG assert model.published is False assert model.etag == ETAG assert model.model_hash == MODEL_HASH assert model.tags == TAGS assert model.locked is True assert model.model_format is None assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON def test_model_success_published(self): model = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 assert model.create_time == CREATE_TIME_MILLIS assert model.update_time == UPDATE_TIME_MILLIS assert model.validation_error is None assert model.published is True assert model.etag == ETAG assert model.model_hash == MODEL_HASH assert model.tags == TAGS assert model.locked is False assert model.model_format == TFLITE_FORMAT assert model.as_dict() == FULL_MODEL_PUBLISHED_JSON def test_model_keyword_based_creation_and_setters(self): model = ml.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) assert model.display_name == DISPLAY_NAME_1 assert model.tags == TAGS assert model.model_format == TFLITE_FORMAT assert model.as_dict() == { 'displayName': DISPLAY_NAME_1, 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON } model.display_name = DISPLAY_NAME_2 model.tags = TAGS_2 model.model_format = TFLITE_FORMAT_2 assert model.as_dict() == { 'displayName': DISPLAY_NAME_2, 'tags': TAGS_2, 'tfliteModel': TFLITE_FORMAT_JSON_2 } def test_model_format_source_creation(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) assert model.as_dict() == { 'displayName': DISPLAY_NAME_1, 'tfliteModel': { 'gcsTfliteUri': GCS_TFLITE_URI } } def test_source_creation_from_tflite_file(self): model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") assert model_source.as_dict() == { 'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite' } def test_model_source_setters(self): model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 def test_model_format_setters(self): model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 assert model_format.as_dict() == { 'tfliteModel': { 'gcsTfliteUri': GCS_TFLITE_URI_2 } } def test_model_as_dict_for_upload(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) assert model.as_dict(for_upload=True) == { 'displayName': DISPLAY_NAME_1, 'tfliteModel': { 'gcsTfliteUri': GCS_TFLITE_SIGNED_URI } } @pytest.mark.parametrize('helper_func', [ ml.TFLiteGCSModelSource.from_keras_model, ml.TFLiteGCSModelSource.from_saved_model ]) def test_tf_not_enabled(self, helper_func): ml._TF_ENABLED = False # for reliability with pytest.raises(ImportError) as excinfo: helper_func(None) check_error(excinfo, ImportError) @pytest.mark.parametrize('display_name, exc_type', [ ('', ValueError), ('&_*#@:/?', ValueError), (12345, TypeError) ]) def test_model_display_name_validation_errors(self, display_name, exc_type): with pytest.raises(exc_type) as excinfo: ml.Model(display_name=display_name) check_error(excinfo, exc_type) @pytest.mark.parametrize('tags, exc_type, error_message', [ ('tag1', TypeError, 'Tags must be a list of strings.'), (123, TypeError, 'Tags must be a list of strings.'), (['tag1', 123, 'tag2'], TypeError, 'Tags must be a list of strings.'), (['tag1', '@#$%^&'], ValueError, 'Tag format is invalid.'), (['', 'tag2'], ValueError, 'Tag format is invalid.'), (['sixty-one_characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', 'tag2'], ValueError, 'Tag format is invalid.') ]) def test_model_tags_validation_errors(self, tags, exc_type, error_message): with pytest.raises(exc_type) as excinfo: ml.Model(tags=tags) check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('model_format', [ 123, "abc", {}, [], True ]) def test_model_format_validation_errors(self, model_format): with pytest.raises(TypeError) as excinfo: ml.Model(model_format=model_format) check_error(excinfo, TypeError, 'Model format must be a ModelFormat object.') @pytest.mark.parametrize('model_source', [ 123, "abc", {}, [], True ]) def test_model_source_validation_errors(self, model_source): with pytest.raises(TypeError) as excinfo: ml.TFLiteFormat(model_source=model_source) check_error(excinfo, TypeError, 'Model source must be a TFLiteModelSource object.') @pytest.mark.parametrize('uri, exc_type', [ (123, TypeError), ('abc', ValueError), ('gs://NO_CAPITALS', ValueError), ('gs://abc/', ValueError), ('gs://aa/model.tflite', ValueError), ('gs://@#$%/model.tflite', ValueError), ('gs://invalid space/model.tflite', ValueError), ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite', ValueError) ]) def test_gcs_tflite_source_validation_errors(self, uri, exc_type): with pytest.raises(exc_type) as excinfo: ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) def test_wait_for_unlocked_not_locked(self): model = ml.Model(display_name="not_locked") model.wait_for_unlocked() def test_wait_for_unlocked(self): recorder = instrument_ml_service(status=200, operations=True, payload=OPERATION_DONE_PUBLISHED_RESPONSE) model = ml.Model.from_dict(LOCKED_MODEL_JSON_1) model.wait_for_unlocked() assert model == FULL_MODEL_PUBLISHED assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestModel._op_url(PROJECT_ID) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_wait_for_unlocked_timeout(self): recorder = instrument_ml_service( status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately model = ml.Model.from_dict(LOCKED_MODEL_JSON_1) with pytest.raises(Exception) as excinfo: model.wait_for_unlocked(max_time_seconds=0.1) check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.') assert len(recorder) == 1 class TestCreateModel: """Tests ml.create_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): testutils.cleanup_apps() @staticmethod def _url(project_id): return BASE_URL + 'projects/{0}/models'.format(project_id) @staticmethod def _op_url(project_id): return BASE_URL + \ 'projects/{0}/operations/123'.format(project_id) @staticmethod def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_immediate_done(self): instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) model = ml.create_model(MODEL_1) assert model == CREATED_UPDATED_MODEL_1 def test_returns_locked(self): recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) model = ml.create_model(MODEL_1) assert model == expected_model assert len(recorder) == 2 assert recorder[0].method == 'POST' assert recorder[0].url == TestCreateModel._url(PROJECT_ID) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: ml.create_model(MODEL_1) # The http request succeeded, the operation returned contains a create failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: ml.create_model(MODEL_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error_create(self): create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: ml.create_model(MODEL_1) check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST ) assert len(create_recorder) == 1 @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) def test_not_model(self, model): with pytest.raises(Exception) as excinfo: ml.create_model(model) check_error(excinfo, TypeError, 'Model must be an ml.Model.') def test_missing_display_name(self): with pytest.raises(Exception) as excinfo: ml.create_model(ml.Model.from_dict({})) check_error(excinfo, ValueError, 'Model must have a display name.') def test_missing_op_name(self): instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) with pytest.raises(Exception) as excinfo: ml.create_model(MODEL_1) check_error(excinfo, TypeError) @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) instrument_ml_service(status=200, payload=payload) with pytest.raises(Exception) as excinfo: ml.create_model(MODEL_1) check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestUpdateModel: """Tests ml.update_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): testutils.cleanup_apps() @staticmethod def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod def _op_url(project_id): return BASE_URL + \ 'projects/{0}/operations/123'.format(project_id) def test_immediate_done(self): instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) model = ml.update_model(MODEL_1) assert model == CREATED_UPDATED_MODEL_1 def test_returns_locked(self): recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) model = ml.update_model(MODEL_1) assert model == expected_model assert len(recorder) == 2 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: ml.update_model(MODEL_1) # The http request succeeded, the operation returned contains an update failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: ml.update_model(MODEL_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error(self): create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: ml.update_model(MODEL_1) check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST ) assert len(create_recorder) == 1 @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) def test_not_model(self, model): with pytest.raises(Exception) as excinfo: ml.update_model(model) check_error(excinfo, TypeError, 'Model must be an ml.Model.') def test_missing_display_name(self): with pytest.raises(Exception) as excinfo: ml.update_model(ml.Model.from_dict({})) check_error(excinfo, ValueError, 'Model must have a display name.') def test_missing_op_name(self): instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) with pytest.raises(Exception) as excinfo: ml.update_model(MODEL_1) check_error(excinfo, TypeError) @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) instrument_ml_service(status=200, payload=payload) with pytest.raises(Exception) as excinfo: ml.update_model(MODEL_1) check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestPublishUnpublish: """Tests ml.publish_model and ml.unpublish_model.""" PUBLISH_UNPUBLISH_WITH_ARGS = [ (ml.publish_model, True), (ml.unpublish_model, False) ] PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS] @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): testutils.cleanup_apps() @staticmethod def _update_url(project_id, model_id): update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format( project_id, model_id) return BASE_URL + update_url @staticmethod def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod def _op_url(project_id): return BASE_URL + \ 'projects/{0}/operations/123'.format(project_id) @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) def test_immediate_done(self, publish_function, published): recorder = instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) model = publish_function(MODEL_ID_1) assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE body = json.loads(recorder[0].body.decode()) assert body.get('state', {}).get('published', None) is published @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_returns_locked(self, publish_function): recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) model = publish_function(MODEL_ID_1) assert model == expected_model assert len(recorder) == 2 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) # The http request succeeded, the operation returned contains an update failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_malformed_operation(self, publish_function): instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_rpc_error(self, publish_function): create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST ) assert len(create_recorder) == 1 class TestGetModel: """Tests ml.get_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) @classmethod def teardown_class(cls): testutils.cleanup_apps() @staticmethod def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_get_model(self): recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) model = ml.get_model(MODEL_ID_1) assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert model == MODEL_1 assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_get_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as excinfo: ml.get_model(model_id) check_error(excinfo, exc_type) def test_get_model_error(self): recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as excinfo: ml.get_model(MODEL_ID_1) check_firebase_error( excinfo, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND ) assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): ml.get_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) class TestDeleteModel: """Tests ml.delete_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) @classmethod def teardown_class(cls): testutils.cleanup_apps() @staticmethod def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_delete_model(self): recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) ml.delete_model(MODEL_ID_1) # no response for delete assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as excinfo: ml.delete_model(model_id) check_error(excinfo, exc_type) def test_delete_model_error(self): recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as excinfo: ml.delete_model(MODEL_ID_1) check_firebase_error( excinfo, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND ) assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): ml.delete_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) class TestListModels: """Tests ml.list_models.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) @classmethod def teardown_class(cls): testutils.cleanup_apps() @staticmethod def _url(project_id): return BASE_URL + 'projects/{0}/models'.format(project_id) @staticmethod def _check_page(page, model_count): assert isinstance(page, ml.ListModelsPage) assert len(page.models) == model_count for model in page.models: assert isinstance(model, ml.Model) def test_list_models_no_args(self): recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) models_page = ml.list_models() assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE TestListModels._check_page(models_page, 2) assert models_page.has_next_page assert models_page.next_page_token == NEXT_PAGE_TOKEN assert models_page.models[0] == MODEL_1 assert models_page.models[1] == MODEL_2 def test_list_models_with_all_args(self): recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) models_page = ml.list_models( 'display_name=displayName3', page_size=10, page_token=PAGE_TOKEN) assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 assert not models_page.has_next_page @pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS) def test_list_models_list_filter_validation(self, list_filter): with pytest.raises(TypeError) as excinfo: ml.list_models(list_filter=list_filter) check_error(excinfo, TypeError, 'List filter must be a string or None.') @pytest.mark.parametrize('page_size, exc_type, error_message', [ ('abc', TypeError, 'Page size must be a number or None.'), (4.2, TypeError, 'Page size must be a number or None.'), (list(), TypeError, 'Page size must be a number or None.'), (dict(), TypeError, 'Page size must be a number or None.'), (True, TypeError, 'Page size must be a number or None.'), (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), (ml._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): with pytest.raises(exc_type) as excinfo: ml.list_models(page_size=page_size) check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS) def test_list_models_page_token_validation(self, page_token): with pytest.raises(TypeError) as excinfo: ml.list_models(page_token=page_token) check_error(excinfo, TypeError, 'Page token must be a string or None.') def test_list_models_error(self): recorder = instrument_ml_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(exceptions.InvalidArgumentError) as excinfo: ml.list_models() check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST ) assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): ml.list_models(app=app) testutils.run_without_project_id(evaluate) def test_list_single_page(self): recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) models_page = ml.list_models() assert len(recorder) == 1 assert models_page.next_page_token == '' assert models_page.has_next_page is False assert models_page.get_next_page() is None models = [model for model in models_page.iterate_all()] assert len(models) == 1 def test_list_multiple_pages(self): # Page 1 recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 2 assert page.next_page_token == NEXT_PAGE_TOKEN assert page.has_next_page is True # Page 2 recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) page_2 = page.get_next_page() assert len(recorder) == 1 assert len(page_2.models) == 1 assert page_2.next_page_token == '' assert page_2.has_next_page is False assert page_2.get_next_page() is None def test_list_models_paged_iteration(self): # Page 1 recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) page = ml.list_models() assert page.next_page_token == NEXT_PAGE_TOKEN assert page.has_next_page is True iterator = page.iterate_all() for index in range(2): model = next(iterator) assert model.display_name == 'displayName{0}'.format(index+1) assert len(recorder) == 1 # Page 2 recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) model = next(iterator) assert model.display_name == DISPLAY_NAME_3 with pytest.raises(StopIteration): next(iterator) def test_list_models_stop_iteration(self): recorder = instrument_ml_service(status=200, payload=ONE_PAGE_LIST_RESPONSE) page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 3 iterator = page.iterate_all() models = [model for model in iterator] assert len(page.models) == 3 with pytest.raises(StopIteration): next(iterator) assert len(models) == 3 def test_list_models_no_models(self): recorder = instrument_ml_service(status=200, payload=NO_MODELS_LIST_RESPONSE) page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 0 models = [model for model in page.iterate_all()] assert len(models) == 0