import contextlib
import datetime
import os
import socket
import time
from collections.abc import Callable
from http import HTTPStatus, client
from unittest import mock

from hutils.shortcuts import str_to_datetime


def disable_migration():
    """Return a mapping-like object that contains every key and maps each key to ``None``.

    NOTE(review): presumably meant as a Django ``MIGRATION_MODULES`` value (``None`` for an app
    meaning "no migrations") — confirm against callers.
    """

    class DisableMigration:
        def __contains__(self, item):
            # Every key is reported as present...
            return True

        def __getitem__(self, item):
            # ...and every key maps to None.
            return None

    return DisableMigration()


def disable_network():
    """Block HTTP connections (and, nominally, raw sockets) for the rest of the process.

    ``http.client.HTTPConnection``, ``urllib3.connection.HTTPConnection`` (only when urllib3 is
    installed) and ``socket.socket`` are replaced by a class whose constructor raises
    ``Exception("Network through socket is disabled!")``. The HTTPConnection replacements are
    never undone.

    :return: an already-started ``mock.patch`` patcher for
        ``asyncio.selector_events.socket.socket``; the caller is responsible for stopping it.
    """

    class DisableNetwork:
        def __init__(self, *args, **kwargs):
            raise Exception("Network through socket is disabled!")

        def __call__(self, *args, **kwargs):
            # Unreachable in practice (``__init__`` always raises); kept as a second guard.
            raise Exception("Network through socket is disabled!")

    real_socket = socket.socket
    client.HTTPConnection = DisableNetwork
    try:
        from urllib3 import connection

        connection.HTTPConnection = DisableNetwork
    except ImportError:
        # urllib3 is optional; nothing to disable when it is absent.
        pass
    socket.socket = DisableNetwork
    # NOTE(review): ``asyncio.selector_events.socket`` is the ``socket`` module itself, so this
    # patch sets ``socket.socket`` back to the real class for *every* caller while it is active
    # (stopping it restores ``DisableNetwork``). While active, only the HTTPConnection
    # replacements above actually block traffic. It looks intended to keep asyncio's
    # ``socket.socketpair()`` self-pipe working — confirm the intended scope.
    patcher = mock.patch("asyncio.selector_events.socket.socket", real_socket)
    patcher.start()
    return patcher


def disable_elastic_apm():
    """Disable the Elastic APM agent through its environment variables.

    Sets ``ELASTIC_APM_DISABLE_SEND=true`` and ``ELASTIC_APM_CENTRAL_CONFIG=false`` in
    ``os.environ`` for the current process.
    """
    os.environ["ELASTIC_APM_DISABLE_SEND"] = "true"
    os.environ["ELASTIC_APM_CENTRAL_CONFIG"] = "false"


class MockDateTime(datetime.datetime):
    """``datetime.datetime`` subclass that :class:`Mogician` installs as ``datetime.datetime``.

    ``now()`` reads ``time.time`` at call time, so it reports the faked clock whenever
    ``time.time`` is patched.
    """

    @classmethod
    def now(cls, tz=None):
        """Return the current time as an instance of ``cls``, honouring ``tz`` like ``datetime.now``.

        :param tz: optional ``tzinfo``; when given, the result is aware and expressed in ``tz``,
            otherwise it is a naive local time.
        """
        # ``time.time`` is looked up on every call, so the patched clock is picked up.
        return cls.fromtimestamp(time.time(), tz)

    def __sub__(self, other):
        """Subtract like ``datetime``, keeping ``datetime`` results as ``MockDateTime``.

        ``datetime - timedelta`` yields a datetime, which is re-wrapped so the mock type survives
        arithmetic; ``datetime - datetime`` yields a ``timedelta`` and is returned unchanged
        (as is ``NotImplemented`` for unsupported operands).
        """
        result = super().__sub__(other)
        if hasattr(result, "timetuple"):
            # Rebuild field by field: a time.mktime() round trip would drop microseconds and
            # tzinfo and reinterpret aware values as local time.
            return MockDateTime(
                result.year,
                result.month,
                result.day,
                result.hour,
                result.minute,
                result.second,
                result.microsecond,
                tzinfo=result.tzinfo,
                fold=result.fold,
            )
        return result


class Mogician:
    """Context manager that freezes the clock at a fixed moment.

    While active it patches ``datetime.datetime`` (with :class:`MockDateTime`),
    ``time.localtime`` and ``time.time``; when Django is importable it also patches
    ``django.db.models.fields.Field.get_default`` and ``django.utils.timezone.now``.

    :param fake_to: the moment to freeze at — a ``datetime.datetime`` or a value accepted by
        ``str_to_datetime``.
    """

    @staticmethod
    def mock_field_default(field):
        """Stand-in for Django's ``Field.get_default`` that routes ``now`` defaults to the faked clock.

        - a callable default whose ``__name__`` is ``"now"`` is replaced by
          ``datetime.datetime.now()`` (``MockDateTime.now`` while patched) — presumably because
          ``default=timezone.now`` holds the real function captured before patching;
        - any other callable default is called; a plain default is returned as-is;
        - with no default, returns ``None`` when the field does not allow empty strings, or when
          it is nullable and the default database treats empty strings as NULL; otherwise ``""``.
        """
        # ``django.db.connection`` is Django's public proxy to the default database connection;
        # the undocumented ``DefaultConnectionProxy`` no longer exists in current Django releases.
        from django.db import connection

        if field.has_default():
            default = field.default
            if callable(default):
                # getattr: callables such as functools.partial objects have no ``__name__``.
                if getattr(default, "__name__", None) == "now":
                    return datetime.datetime.now()
                return default()
            return default
        if not field.empty_strings_allowed or (
            field.null and not connection.features.interprets_empty_strings_as_nulls
        ):
            return None
        return ""

    def __init__(self, fake_to):
        self.the_datetime = fake_to if isinstance(fake_to, datetime.datetime) else str_to_datetime(fake_to)
        # Captured before patching so explicit ``time.localtime(secs)`` calls still convert ``secs``.
        real_localtime = time.localtime

        def fake_localtime(secs=None):
            # ``time.localtime`` takes an optional timestamp; only the "current time" form is frozen.
            if secs is None:
                return time.struct_time(self.the_datetime.timetuple())
            return real_localtime(secs)

        self.patchers = [
            mock.patch("datetime.datetime", MockDateTime),
            mock.patch("time.localtime", fake_localtime),
            # The frozen moment is interpreted as local time, matching ``time.localtime`` above.
            mock.patch("time.time", lambda: time.mktime(self.the_datetime.timetuple())),
        ]
        try:
            import django  # NOQA

            self.patchers.extend(
                [
                    mock.patch("django.db.models.fields.Field.get_default", Mogician.mock_field_default),
                    mock.patch("django.utils.timezone.now", MockDateTime.now),
                ]
            )
        except ImportError:
            # Django is optional; without it only the stdlib clocks are faked.
            pass

    def __enter__(self):
        """Start every patcher and return ``self``; on a start failure, undo the ones already started."""
        started = []
        try:
            for patcher in self.patchers:
                patcher.start()
                started.append(patcher)
        except BaseException:
            for patcher in reversed(started):
                patcher.stop()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop every patcher in reverse start order; exceptions are never suppressed."""
        for patcher in reversed(self.patchers):
            patcher.stop()


@contextlib.contextmanager
def fake_time(fake_to):
    """
    short cut for mocking time or datetime, supports django.

    Examples::

        @fake_time('2018-08-08 12:00:00')
        def test_something_related_to_datetime(self):
            pass

    :type fake_to: str | datetime.datetime
    """
    with Mogician(fake_to):
        yield


class TestCaseMixin:
    """
    Mixin that adds a few small helpers to make tests easier to write.

    Examples:

        from rest_framework.test import APITestCase

        class TestCase(APITestCase, TestCaseMixin):
            pass

        class ExampleTest(TestCase):
            def test_something(self):
                response = self.client.get(url)
                self.ok(response)

    For details, see <tests.test_unittest.FuncTestCaseAPITests>
    """

    def ok(self, response, *, is_201=False, is_204=False, **kwargs):
        """Assert a 2xx response (200 by default, 201/204 via the flags) and optionally its data.

        :param is_201: expect ``201 Created`` (takes precedence over ``is_204``).
        :param is_204: expect ``204 No Content``.
        :param kwargs: forwarded to :meth:`assert_same` against ``response.data``.
        :return: ``self``, so checks can be chained.
        """
        if is_201:
            expected = HTTPStatus.CREATED
        elif is_204:
            expected = HTTPStatus.NO_CONTENT
        else:
            expected = HTTPStatus.OK
        self.assertEqual(
            expected,
            response.status_code,
            "status code should be {}: {}".format(expected, getattr(response, "data", "")),
        )
        if kwargs:
            self.assert_same(response.data, **kwargs)
        return self

    def bad_request(self, response, **kwargs):
        """Assert a 400 response and optionally check ``response.data`` via :meth:`assert_same`."""
        self.assertEqual(HTTPStatus.BAD_REQUEST, response.status_code, "status code should be 400")
        if kwargs:
            self.assert_same(response.data, **kwargs)
        return self

    def not_found(self, response, **kwargs):
        """Assert a 404 response and optionally check ``response.data`` via :meth:`assert_same`."""
        self.assertEqual(HTTPStatus.NOT_FOUND, response.status_code, "status code should be 404")
        if kwargs:
            self.assert_same(response.data, **kwargs)
        return self

    def forbidden(self, response, **kwargs):
        """Assert a 403 response and optionally check ``response.data`` via :meth:`assert_same`."""
        self.assertEqual(HTTPStatus.FORBIDDEN, response.status_code, "status code should be 403")
        if kwargs:
            self.assert_same(response.data, **kwargs)
        return self

    def assert_increases(self, delta: int, func: Callable, name=""):
        """Return a context manager asserting that ``func()`` grows by ``delta`` across the block.

        ``func()`` is sampled on enter and again on exit; the check is skipped when the block
        raises, so the original exception propagates.

        :param name: label used in the failure message.
        """
        test_case = self

        class Detector:
            def __init__(self):
                self.previous = None

            def __enter__(self):
                self.previous = func()

            def __exit__(self, exc_type, exc_val, exc_tb):
                if not exc_val:
                    test_case.assertEqual(
                        self.previous + delta, func(), "{} should change {}".format(name, delta).strip()
                    )

        return Detector()

    def assert_model_increases(self, *models, delta: int = 1, **lookups):
        """Assert that each model's row count changes by an expected amount across a ``with`` block.

        :param models: model classes, or ``(model, delta)`` tuples giving a model its own delta.
        :param delta: expected change for models passed without an explicit delta.
        :param lookups: filter applied to ``model.all_objects`` before counting.
        :return: a ``contextlib.ExitStack``. The "before" counts are taken when this method is
            called (``enter_context`` enters each detector immediately).
        """
        stack = contextlib.ExitStack()
        for case in models:
            if isinstance(case, tuple):
                model, expected = case
            else:
                # Bare models fall back to the ``delta`` keyword.
                model, expected = case, delta
            stack.enter_context(
                self.assert_increases(expected, model.all_objects.filter(**lookups).count, model.__name__)
            )
        return stack

    def assert_same(self, data, **expects):
        """Compare values inside ``data`` addressed by ``__``-separated keys.

        Each key segment is applied in turn:

        - ``length`` -> ``len(value)``; ``bool`` -> ``bool(value)``;
        - ``_<n>`` -> ``value[n]`` when ``n`` is an int, else ``getattr(value, n)``;
        - anything else -> ``value[int(part)]`` when numeric, else ``value[part]``.

        On mismatch the full ``data`` and ``expects`` are printed before the failure is re-raised.
        """

        def _get_key(_data, _key: str):
            """Walk ``_data`` following the ``__``-separated segments of ``_key``."""
            _value = _data
            for part in _key.split("__"):
                if part == "length":
                    _value = len(_value)
                elif part == "bool":
                    _value = bool(_value)
                elif part.startswith("_"):
                    try:
                        _value = _value[int(part[1:])]
                    except ValueError:
                        _value = getattr(_value, part[1:])
                else:
                    try:
                        _value = _value[int(part)]
                    except ValueError:
                        _value = _value[part]
            return _value

        for key, expect in expects.items():
            actual = _get_key(data, key)
            try:
                self.assertEqual(
                    expect,
                    actual,
                    "{} value not match.\nExpect: {} ({})\nActual: {} ({})".format(
                        key, expect, type(expect), actual, type(actual)
                    ),
                )
            except Exception:
                # Dump the whole payload for context, then let the failure propagate.
                print("\nAssertionError:")
                print("Actual: {}".format(data))
                print("Expect: {}".format(expects))
                raise
        return self

    def assert_data(self, expected_data, actual_data):
        """Recursively assert that ``expected_data`` matches ``actual_data``.

        Lists must match in length and element-wise; dicts only need their keys present in
        ``actual_data`` (extra keys are allowed); other values are compared with ``assertEqual``.
        """
        if isinstance(expected_data, list):
            data = list(actual_data)
            self.assertEqual(len(expected_data), len(data))
            for index, item in enumerate(expected_data):
                self.assert_data(item, data[index])
        elif isinstance(expected_data, dict):
            for k, v in expected_data.items():
                self.assertIn(k, actual_data, msg="{} not in actual_data".format(k))
                self.assert_data(v, actual_data[k])
        else:
            self.assertEqual(expected_data, actual_data)
        return self