# Copyright (C) 2018 British Broadcasting Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from requests.compat import json
import git
import jsonschema
import traceback
import inspect
import uuid
import functools

from . import TestHelper
from .Specification import Specification
from .TestResult import Test
from . import Config as CONFIG


NMOS_WIKI_URL = "https://github.com/AMWA-TV/nmos/wiki"


def test_depends(func):
    """Decorator to prevent a test being executed in individual mode

    The wrapped test method runs normally unless the owning object's
    `test_individual` flag is set, in which case the test's description is set
    to "Invalid" and a DISABLED result is returned instead of calling `func`.
    """
    # functools.wraps carries over __name__ and __doc__ (which execute_test relies on via
    # inspect.getdoc) along with the rest of the wrapped function's metadata
    @functools.wraps(func)
    def invalid(self, test):
        if self.test_individual:
            test.description = "Invalid"
            return test.DISABLED("This test cannot be performed individually")
        return func(self, test)

    return invalid


class NMOSTestException(Exception):
    """Provides a way to exit a single test, by providing the TestResult return statement as the first exception
       parameter"""
    pass


class NMOSInitException(Exception):
    """The test set was run in an invalid mode. Causes all tests to abort"""
    pass


class GenericTest(object):
    """
    Generic testing class.
    Can be inherited from in order to perform detailed testing.
    """

    def __init__(self, apis, omit_paths=None, disable_auto=False):
        """Prepare the specification repositories and parse the RAML for each API

        :param apis: dict keyed by API name. Each value is a dict; the keys read by this class are
                     "spec_path", "version", "raml", "url" and "base_url". "spec_branch" and "spec" are
                     written back into it here and in parse_RAML.
        :param omit_paths: optional list of resource paths which basics() should not test automatically.
                           Anything other than a list is ignored.
        :param disable_auto: when True, the automatically defined tests in basics() are not run
        :raises Exception: if a specification repository has neither a '<version>.x' nor a '<version>-dev'
                           remote branch
        """
        self.apis = apis
        self.saved_entities = {}
        self.auto_test_count = 0
        self.test_individual = False
        self.result = list()
        self.protocol = "http"
        self.ws_protocol = "ws"
        if CONFIG.ENABLE_HTTPS:
            self.protocol = "https"
            self.ws_protocol = "wss"

        self.omit_paths = []
        if isinstance(omit_paths, list):
            self.omit_paths = omit_paths
        self.disable_auto = disable_auto

        test = Test("Test initialisation")

        for api_name, api_data in self.apis.items():
            if "spec_path" not in api_data or api_data["version"] is None:
                continue

            repo = git.Repo(api_data["spec_path"])

            # List remote branches and check there is a v#.#.x or v#.#-dev
            branches = repo.git.branch('-a')
            spec_branch = None
            branch_names = [api_data["version"] + ".x", api_data["version"] + "-dev"]
            for branch in branch_names:
                if "remotes/origin/" + branch in branches:
                    spec_branch = branch
                    break

            if not spec_branch:
                raise Exception("No branch matching the expected patterns was found in the Git repository")

            api_data["spec_branch"] = spec_branch

            # Discard local changes and bring the working copy up to date with the remote branch
            repo.git.reset('--hard')
            repo.git.checkout(spec_branch)
            repo.git.rebase("origin/" + spec_branch)

        self.parse_RAML()

        self.result.append(test.NA(""))

    def parse_RAML(self):
        """Create a Specification object for each API defined in this object

        APIs without a "spec_path" are skipped. The result is stored under the API's "spec" key.
        """
        for api in self.apis:
            if "spec_path" not in self.apis[api]:
                continue
            raml_path = os.path.join(self.apis[api]["spec_path"] + '/APIs/' + self.apis[api]["raml"])
            self.apis[api]["spec"] = Specification(raml_path)

    def execute_tests(self, test_names):
        """Perform tests defined within this class

        :param test_names: iterable of test names, each handled by execute_test
        """
        for test_name in test_names:
            self.execute_test(test_name)

    def execute_test(self, test_name):
        """Perform a test defined within this class

        :param test_name: "all" to run the automatic tests followed by every callable attribute whose name
                          starts with "test_", "auto" to run only the automatic tests, or the name of a
                          single test method. An unknown method name raises AttributeError from getattr.
        """
        self.test_individual = (test_name != "all")

        # Run automatically defined tests
        if test_name in ["auto", "all"] and not self.disable_auto:
            print(" * Running basic API tests")
            self.result += self.basics()

        if test_name == "all":
            # Run manually defined tests
            for method_name in dir(self):
                if method_name.startswith("test_"):
                    self._run_named_test(method_name)
        elif test_name != "auto":
            # Run a single test
            self._run_named_test(test_name)

    def _run_named_test(self, method_name):
        """Run one test method by name and append its result to self.result"""
        method = getattr(self, method_name)
        if not callable(method):
            return

        print(" * Running " + method_name)
        test = Test(inspect.getdoc(method), method_name)
        try:
            self.result.append(method(test))
        except NMOSTestException as e:
            # The test exited early; its result is carried as the first exception argument
            self.result.append(e.args[0])
        except Exception as e:
            self.result.append(self.uncaught_exception(method_name, e))

    def uncaught_exception(self, test_name, exception):
        """Print a traceback and provide a test FAIL result for uncaught exceptions"""
        traceback.print_exc()
        test = Test("Error executing {}".format(test_name), test_name)
        return test.FAIL("Uncaught exception. Please report the traceback from the terminal to "
                         "https://github.com/amwa-tv/nmos-testing/issues. {}".format(exception))

    def set_up_tests(self):
        """Called before a set of tests is run. Override this method with setup code."""
        pass

    def tear_down_tests(self):
        """Called after a set of tests is run. Override this method with teardown code."""
        pass

    def run_tests(self, test_name=None):
        """Perform tests and return the results as a list

        :param test_name: list of test names as accepted by execute_test. Defaults to ["all"].
        :raises NMOSInitException: if CONFIG.PREVALIDATE_API is set and an API with a "spec_path" and a
                                   URL does not answer a GET with status 200
        :return: self.result, including the set-up and tear-down entries
        """
        # A None sentinel avoids sharing one mutable default list between calls
        if test_name is None:
            test_name = ["all"]

        # Set up
        test = Test("Test setup", "set_up_tests")
        if CONFIG.PREVALIDATE_API:
            for api in self.apis:
                if "spec_path" not in self.apis[api] or self.apis[api]["url"] is None:
                    continue
                valid, response = self.do_request("GET", self.apis[api]["url"])
                if not valid or response.status_code != 200:
                    raise NMOSInitException("No API found at {}".format(self.apis[api]["url"]))
        self.set_up_tests()
        self.result.append(test.NA(""))

        # Run tests
        self.execute_tests(test_name)

        # Tear down
        test = Test("Test teardown", "tear_down_tests")
        self.tear_down_tests()
        self.result.append(test.NA(""))

        return self.result

    def convert_bytes(self, data):
        """Convert bytes which may be contained within a dict or tuple into strings

        bytes are decoded as ASCII. Dicts and tuples are rebuilt with their contents converted
        recursively. Any other value is returned unchanged.
        """
        if isinstance(data, bytes):
            return data.decode('ascii')
        if isinstance(data, dict):
            return {self.convert_bytes(key): self.convert_bytes(value) for key, value in data.items()}
        if isinstance(data, tuple):
            # Build a real tuple: a bare map() is a lazy iterator which can only be consumed once
            return tuple(self.convert_bytes(item) for item in data)
        return data

    def prepare_CORS(self, method, request_headers):
        """Prepare CORS headers to be used when making any API request

        :param method: value for the 'Access-Control-Request-Method' header
        :param request_headers: iterable of header names, joined with ", " into
                                'Access-Control-Request-Headers'
        :return: dict containing the two request headers
        """
        headers = {}
        headers['Access-Control-Request-Method'] = method  # Match to request type
        headers['Access-Control-Request-Headers'] = ", ".join(request_headers)
        return headers

    # 'check' functions return a Boolean pass/fail indicator and a message
    def check_CORS(self, method, headers, expect_methods=None, expect_headers=None):
        """Check the CORS headers returned by an API call

        'Access-Control-Allow-Origin' must always be present. For OPTIONS requests, each entry of
        expect_headers / expect_methods (when given) must appear, compared case-insensitively, in
        'Access-Control-Allow-Headers' / 'Access-Control-Allow-Methods' respectively.

        :return: (True, "") on success, otherwise (False, message)
        """
        if 'Access-Control-Allow-Origin' not in headers:
            return False, "'Access-Control-Allow-Origin' not in CORS headers: {}".format(headers)

        is_options = method.upper() == "OPTIONS"

        if is_options and expect_headers is not None:
            if 'Access-Control-Allow-Headers' not in headers:
                return False, "'Access-Control-Allow-Headers' not in CORS headers: {}".format(headers)
            current_headers = [x.strip().upper() for x in headers['Access-Control-Allow-Headers'].split(",")]
            for cors_header in expect_headers:
                if cors_header.upper() not in current_headers:
                    return False, "'{}' not in 'Access-Control-Allow-Headers' CORS header: {}" \
                                  .format(cors_header, headers['Access-Control-Allow-Headers'])

        if is_options and expect_methods is not None:
            if 'Access-Control-Allow-Methods' not in headers:
                return False, "'Access-Control-Allow-Methods' not in CORS headers: {}".format(headers)
            current_methods = [x.strip().upper() for x in headers['Access-Control-Allow-Methods'].split(",")]
            for cors_method in expect_methods:
                if cors_method.upper() not in current_methods:
                    return False, "'{}' not in 'Access-Control-Allow-Methods' CORS header: {}" \
                                  .format(cors_method, headers['Access-Control-Allow-Methods'])

        return True, ""

    def check_content_type(self, headers, expected_type="application/json"):
        """Check the Content-Type header of an API request or response

        :return: (False, message) if the header is missing, names a different type, or carries parameters
                 other than a single 'charset=utf-8'; (True, message) with a non-empty message if only
                 that charset parameter is present; (True, "") otherwise
        """
        if "Content-Type" not in headers:
            return False, "API failed to signal a Content-Type."

        ctype = headers["Content-Type"]
        ctype_params = ctype.split(";")
        if ctype_params[0] != expected_type:
            return False, "API signalled a Content-Type of {} rather than {}." \
                          .format(ctype, expected_type)
        elif len(ctype_params) == 2 and ctype_params[1].strip().lower() == "charset=utf-8":
            # Accepted, but the message lets callers report it (see check_response callers)
            return True, "API signalled an unnecessary 'charset' in its Content-Type: {}" \
                         .format(ctype)
        elif len(ctype_params) >= 2:
            return False, "API signalled unexpected additional parameters in its Content-Type: {}" \
                          .format(ctype)
        return True, ""

    def check_accept(self, headers):
        """Check the Accept header of an API request

        When an Accept header is present, 'application/json' or '*/*' must be among the media types with
        the highest 'q' weight, and no other media type may share that weight. A missing header passes.

        :return: (True, "") on success, otherwise (False, message)
        """
        if "Accept" in headers:
            accept_params = headers["Accept"].split(",")
            max_weight = 0
            max_weight_types = []
            for param in accept_params:
                param_parts = param.split(";")
                media_type = param_parts[0].strip()
                weight = 1  # Used when there is no parsable 'q' parameter
                for ext_param in param_parts[1:]:
                    if ext_param.strip().startswith("q="):
                        try:
                            weight = float(ext_param.split("=")[1].strip())
                            break
                        except ValueError:
                            # Not a number: ignore this parameter and keep looking
                            pass
                if weight > max_weight:
                    max_weight = weight
                    max_weight_types.clear()
                if weight == max_weight:
                    max_weight_types.append(media_type)

            if "application/json" not in max_weight_types and "*/*" not in max_weight_types:
                return False, "API did not signal a preference for application/json via its Accept header."

            # Drop one occurrence of each acceptable type; anything left over is a competing media type
            for acceptable_type in ("application/json", "*/*"):
                try:
                    max_weight_types.remove(acceptable_type)
                except ValueError:
                    pass
            if len(max_weight_types) > 0:
                return False, "API signalled multiple media types with the same preference as application/json in " \
                              "its Accept header: {}".format(max_weight_types)

        return True, ""

    def auto_test_name(self, api_name):
        """Get the name which should be used for an automatically defined test

        Increments self.auto_test_count on every call.
        """
        self.auto_test_count += 1
        return "auto_{}_{}".format(api_name, self.auto_test_count)

    # 'do_test' functions either return a TestResult, or raise an NMOSTestException when there's an error
    def do_test_base_path(self, api_name, base_url, path, expectation):
        """Check that a GET to a path returns a JSON array containing a defined string

        :param api_name: API name used to build the automatic test name
        :param base_url: URL prefix which `path` is appended to
        :param path: path to request
        :param expectation: value which must be a member of the returned JSON array
        """
        test = Test("GET {}".format(path), self.auto_test_name(api_name))
        valid, response = self.do_request("GET", base_url + path)
        if not valid:
            return test.FAIL("Unable to connect to API: {}".format(response))

        if response.status_code != 200:
            return test.FAIL("Incorrect response code: {}".format(response.status_code))

        cors_valid, cors_message = self.check_CORS('GET', response.headers)
        if not cors_valid:
            return test.FAIL(cors_message)

        try:
            payload = response.json()  # Parse the body once rather than once per check
        except json.JSONDecodeError:
            return test.FAIL("Non-JSON response returned")

        if not isinstance(payload, list) or expectation not in payload:
            return test.FAIL("Response is not an array containing '{}'".format(expectation))
        return test.PASS()

    def check_response(self, schema, method, response):
        """Confirm that a given Requests response conforms to the expected schema and has any expected headers

        :return: (False, message) on the first failing check, otherwise (True, message) where the message
                 is whatever check_content_type returned (possibly a non-empty advisory)
        """
        ctype_valid, ctype_message = self.check_content_type(response.headers)
        if not ctype_valid:
            return False, ctype_message

        cors_valid, cors_message = self.check_CORS(method, response.headers)
        if not cors_valid:
            return False, cors_message

        try:
            self.validate_schema(response.json(), schema)
        except jsonschema.ValidationError:
            return False, "Response schema validation error"
        except json.JSONDecodeError:
            return False, "Invalid JSON received"

        return True, ctype_message

    def check_error_response(self, method, response, code):
        """Confirm that a given Requests response conforms to the 4xx/5xx error schema and has any expected headers

        In addition to check_response against the error schema, the body's "code" value must equal `code`.

        :return: (True, "") on success, otherwise (False, message)
        """
        schema = TestHelper.load_resolved_schema("test_data/core", "error.json", path_prefix=False)
        valid, message = self.check_response(schema, method, response)
        if not valid:
            return False, message

        if response.json()["code"] != code:
            return False, "Error JSON 'code' was not set to {}".format(code)
        return True, ""

    def validate_schema(self, payload, schema):
        """
        Validate the payload under the given schema.
        Raises an exception if the payload (or schema itself) is invalid
        """
        checker = jsonschema.FormatChecker(["ipv4", "ipv6", "uri"])
        jsonschema.validate(payload, schema, format_checker=checker)

    def do_request(self, method, url, **kwargs):
        """Make a request via TestHelper.do_request, forwarding any extra keyword arguments

        Callers in this class unpack the result as a (valid, response) pair.
        """
        return TestHelper.do_request(method=method, url=url, **kwargs)

    def basics(self):
        """Perform basic API read requests (GET etc.) relevant to all API definitions

        APIs without a "spec_path" or with a URL of None are skipped.

        :return: list of results from the automatically defined tests
        """
        results = []

        for api in sorted(self.apis.keys()):
            if "spec_path" not in self.apis[api]:
                continue
            if self.apis[api]["url"] is None:
                continue

            # Set the auto test count to zero as each test name includes the API type
            self.auto_test_count = 0

            # We don't check the very base of the URL (before x-nmos) as it may be used for other things
            results.append(self.do_test_base_path(api, self.apis[api]["base_url"], "/x-nmos", api + "/"))
            results.append(self.do_test_base_path(api, self.apis[api]["base_url"], "/x-nmos/{}".format(api),
                                                  self.apis[api]["version"] + "/"))

            for resource in self.apis[api]["spec"].get_reads():
                for response_code in resource[1]['responses']:
                    if response_code == 200 and resource[0] not in self.omit_paths:
                        # TODO: Test for each of these if the trailing slash version also works and if redirects are
                        # used on either.
                        result = self.do_test_api_resource(resource, response_code, api)
                        if result is not None:
                            results.append(result)

            # Perform an automatic check for an error condition
            results.append(self.do_test_404_path(api))

        return results

    def do_test_404_path(self, api_name):
        """Check that a GET to a randomly generated, non-existent path returns a conformant 404 error"""
        api = self.apis[api_name]
        error_code = 404
        invalid_path = str(uuid.uuid4())  # A fresh UUID is used as a path which should not exist
        url = "{}/{}".format(api["url"].rstrip("/"), invalid_path)
        test = Test("GET /x-nmos/{}/{}/{} ({})".format(api_name, api["version"], invalid_path, error_code),
                    self.auto_test_name(api_name))

        valid, response = self.do_request("GET", url)
        if not valid:
            return test.FAIL(response)

        if response.status_code != error_code:
            return test.FAIL("Incorrect response code, expected {}: {}".format(error_code, response.status_code))

        valid, message = self.check_error_response("GET", response, error_code)
        if valid:
            return test.PASS()
        else:
            return test.FAIL(message)

    def do_test_api_resource(self, resource, response_code, api):
        """Request a single resource defined in the specification and check the response

        :param resource: pair of (path, details) as yielded by the Specification's get_reads(); the details
                         keys read here are 'params' and 'method'
        :param response_code: status code which the API is expected to return
        :param api: name of the API which the resource belongs to
        :return: a test result, or None if the path has more than one parameter and so cannot be tested
        """
        # Test URLs which include a {resourceId} or similar parameter
        if resource[1]['params'] and len(resource[1]['params']) == 1:
            path = resource[0].split("{")[0].rstrip("/")
            if path in self.saved_entities:
                # Pick the first relevant saved entity and construct a test
                entity = self.saved_entities[path][0]
                params = {resource[1]['params'][0].name: entity}
                url_param = resource[0].format(**params)
                url = "{}{}".format(self.apis[api]["url"].rstrip("/"), url_param)
                test = Test("{} /x-nmos/{}/{}{}".format(resource[1]['method'].upper(),
                                                        api,
                                                        self.apis[api]["version"],
                                                        url_param), self.auto_test_name(api))
            else:
                # There were no saved entities found, so we can't test this parameterised URL
                test = Test("{} /x-nmos/{}/{}{}".format(resource[1]['method'].upper(),
                                                        api,
                                                        self.apis[api]["version"],
                                                        resource[0].rstrip("/")), self.auto_test_name(api))
                return test.UNCLEAR("No resources found to perform this test")

        # Test general URLs with no parameters
        elif not resource[1]['params']:
            url = "{}{}".format(self.apis[api]["url"].rstrip("/"), resource[0].rstrip("/"))
            test = Test("{} /x-nmos/{}/{}{}".format(resource[1]['method'].upper(),
                                                    api,
                                                    self.apis[api]["version"],
                                                    resource[0].rstrip("/")), self.auto_test_name(api))
        else:
            return None

        headers = None
        cors_methods = None
        cors_headers = None
        if resource[1]['method'].upper() == "OPTIONS":
            # Build a new list rather than removing "OPTIONS" in place, so that the list handed back by the
            # Specification (which may be shared between calls) is never modified
            cors_methods = [spec_method for spec_method in self.apis[api]["spec"].get_methods(resource[0])
                            if spec_method != "OPTIONS"]
            cors_headers = ["Content-Type"]
            if cors_methods:
                # Check if one of the supported methods, and a header are permitted in other requests
                headers = self.prepare_CORS(cors_methods[0], cors_headers)
            # Otherwise OPTIONS is the only method for this path: send the request without pre-flight headers

        valid, response = self.do_request(resource[1]['method'], url, headers=headers)
        if not valid:
            return test.FAIL(response)

        if response.status_code != response_code:
            return test.FAIL("Incorrect response code: {}".format(response.status_code))

        # Gather IDs of sub-resources for testing of parameterised URLs...
        self.save_subresources(resource[0], response)

        cors_valid, cors_message = self.check_CORS(resource[1]['method'], response.headers,
                                                   cors_methods, cors_headers)
        if not cors_valid:
            # Fail immediately for CORS errors affecting any method
            return test.FAIL(cors_message)
        elif resource[1]['method'].upper() in ["HEAD", "OPTIONS"]:
            # For methods which don't return a payload, return immediately after the CORS header check
            return test.PASS()

        # For all other methods proceed to check the response against the schema
        schema = self.get_schema(api, resource[1]["method"], resource[0], response.status_code)

        if not schema:
            return test.MANUAL("Test suite unable to locate schema")

        valid, message = self.check_response(schema, resource[1]["method"], response)

        if valid:
            if message:
                return test.WARNING(message)
            else:
                return test.PASS()
        else:
            return test.FAIL(message)

    def save_subresources(self, path, response):
        """Get IDs contained within an array JSON response such that they can be interrogated individually

        Any IDs found are appended to self.saved_entities[path]. A body which cannot be decoded as JSON is
        ignored.
        """
        try:
            payload = response.json()  # Parse the body once rather than once per check
        except json.JSONDecodeError:
            return

        subresources = list()
        if isinstance(payload, list):
            for entry in payload:
                # In general, lists return fully fledged objects which each have an ID
                if isinstance(entry, dict) and "id" in entry:
                    subresources.append(entry["id"])
                # In some cases lists contain strings which indicate the path to each resource
                elif isinstance(entry, str) and entry.endswith("/"):
                    res_id = entry.rstrip("/")
                    subresources.append(res_id)
        elif isinstance(payload, dict):
            for key, value in payload.items():
                # Cover the audio channel mapping spec case with dictionary keys
                if isinstance(key, str) and isinstance(value, dict):
                    subresources.append(key)

        if len(subresources) > 0:
            if path not in self.saved_entities:
                self.saved_entities[path] = subresources
            else:
                self.saved_entities[path] += subresources

    def get_schema(self, api_name, method, path, status_code):
        """Look up the schema for a method, path and status code via the API's Specification object"""
        return self.apis[api_name]["spec"].get_schema(method, path, status_code)