"""
Tests for the utils module
"""
import copy
import datetime
from math import ceil
import os
import pathlib
import tempfile
import unittest
from unittest.mock import patch

import ddt
from django.core.exceptions import ImproperlyConfigured
from django.template import RequestContext
from django.test import (
    override_settings,
    RequestFactory,
)
import pytz
from rest_framework import status
from rest_framework.exceptions import ValidationError

from courses.factories import CourseRunFactory
from ecommerce.factories import (
    ReceiptFactory,
)
from ecommerce.models import Order
from financialaid.factories import (
    FinancialAidFactory,
)
from micromasters.exceptions import PossiblyImproperlyConfigured
from micromasters.utils import (
    as_datetime,
    chunks,
    custom_exception_handler,
    dict_with_keys,
    first_matching_item,
    get_field_names,
    is_near_now,
    is_subset_dict,
    now_in_utc,
    remove_falsey_values,
    safely_remove_file,
    serialize_model_object,
    pop_keys_from_dict,
    pop_matching_keys_from_dict,
    generate_md5,
)
from search.base import MockedESTestCase


@ddt.ddt
class ExceptionHandlerTest(MockedESTestCase):
    """
    Tests for the custom_exception_handler function.
    This is a Django Rest framework custom exception handler
    """

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        cls.request = RequestFactory()
        cls.context = RequestContext(cls.request)

    @patch('sentry_sdk.capture_exception', autospec=True)
    def test_validation_error(self, mock_sentry):
        """
        Test a standard exception handled by default by the rest framework
        """
        exp = ValidationError('validation error')
        resp = custom_exception_handler(exp, self.context)
        assert resp.status_code == status.HTTP_400_BAD_REQUEST
        assert resp.data == ['validation error']
        assert mock_sentry.called is False

    @patch('sentry_sdk.capture_exception', autospec=True)
    @ddt.data(
        ImproperlyConfigured,
        PossiblyImproperlyConfigured,
    )
    def test_improperly_configured(self, exception_to_raise, mock_sentry):
        """
        Test a standard exception not handled by default by the rest framework
        """
        exp = exception_to_raise('improperly configured')
        resp = custom_exception_handler(exp, self.context)
        assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
        assert resp.data == ['{0}: improperly configured'.format(exception_to_raise.__name__)]
        mock_sentry.assert_called_once_with()

    @patch('sentry_sdk.capture_exception', autospec=True)
    def test_index_error(self, mock_sentry):
        """
        Test that other kinds of exceptions are not handled
        """
        exp = IndexError('index error')
        resp = custom_exception_handler(exp, self.context)
        assert resp is None
        assert mock_sentry.called is False


def format_as_iso8601(time):
    """
    Format a datetime as an ISO 8601 string with a trailing "Z".

    Used to build the expected values for serialize_model_object in the tests below.
    A non-zero microsecond component is truncated to three digits (milliseconds).
    """
    # Can't use datetime.isoformat() because format is slightly different from this
    iso_format = '%Y-%m-%dT%H:%M:%S'
    formatted_time = time.strftime(iso_format)
    if time.microsecond:
        milliseconds_format = '.%f'
        # '.%f' renders as '.ffffff'; keep the dot plus the first three digits
        formatted_time += time.strftime(milliseconds_format)[:4]
    return formatted_time + "Z"


class SerializerTests(MockedESTestCase):
    """
    Tests for serialize_model
    """

    def test_jsonfield(self):
        """
        Test a model with a JSONField is handled correctly
        """
        with override_settings(CYBERSOURCE_SECURITY_KEY='asdf'):
            receipt = ReceiptFactory.create()
        assert serialize_model_object(receipt) == {
            'created_at': format_as_iso8601(receipt.created_at),
            'data': receipt.data,
            'id': receipt.id,
            'modified_at': format_as_iso8601(receipt.modified_at),
            'order': receipt.order.id,
        }

    def test_datetime(self):
        """
        Test that a model with a datetime and date field is handled correctly
        """
        financial_aid = FinancialAidFactory.create(justification=None)
        assert serialize_model_object(financial_aid) == {
            'country_of_income': financial_aid.country_of_income,
            'country_of_residence': financial_aid.country_of_residence,
            'created_on': format_as_iso8601(financial_aid.created_on),
            'date_documents_sent': financial_aid.date_documents_sent.isoformat(),
            'date_exchange_rate': format_as_iso8601(financial_aid.date_exchange_rate),
            'id': financial_aid.id,
            'income_usd': financial_aid.income_usd,
            'justification': None,
            'original_currency': financial_aid.original_currency,
            'original_income': financial_aid.original_income,
            'status': financial_aid.status,
            'tier_program': financial_aid.tier_program.id,
            'updated_on': format_as_iso8601(financial_aid.updated_on),
            'user': financial_aid.user.id,
        }

    def test_decimal(self):
        """
        Test that a model with a decimal field is handled correctly
        """
        course_run = CourseRunFactory.create()
        assert serialize_model_object(course_run) == {
            'course': course_run.course.id,
            'edx_course_key': course_run.edx_course_key,
            'end_date': format_as_iso8601(course_run.end_date),
            'enrollment_end': format_as_iso8601(course_run.enrollment_end),
            'enrollment_start': format_as_iso8601(course_run.enrollment_start),
            'enrollment_url': course_run.enrollment_url,
            'freeze_grade_date': format_as_iso8601(course_run.freeze_grade_date),
            'fuzzy_enrollment_start_date': course_run.fuzzy_enrollment_start_date,
            'fuzzy_start_date': course_run.fuzzy_start_date,
            'id': course_run.id,
            'prerequisites': course_run.prerequisites,
            'start_date': format_as_iso8601(course_run.start_date),
            'title': course_run.title,
            'upgrade_deadline': format_as_iso8601(course_run.upgrade_deadline),
        }


class FieldNamesTests(unittest.TestCase):
    """
    Tests for get_field_names
    """

    def test_get_field_names(self):
        """
        Assert that get_field_names does not include related fields
        """
        assert set(get_field_names(Order)) == {
            'user',
            'status',
            'total_price_paid',
            'created_at',
            'modified_at',
        }


class UtilTests(unittest.TestCase):
    """
    Tests for assorted utility functions
    """

    def test_first_matching_item(self):
        """
        Tests that first_matching_item returns a matching item in an iterable, or None
        """
        iterable = [1, 2, 3, 4, 5]
        first_matching = first_matching_item(iterable, lambda item: item == 1)
        second_matching = first_matching_item(iterable, lambda item: item == 5)
        third_matching = first_matching_item(iterable, lambda item: item == 10)
        assert first_matching == 1
        assert second_matching == 5
        assert third_matching is None

    def test_remove_falsey_values(self):
        """
        Tests that remove_falsey_values returns a generator that yields only truthy values from an iterable
        """
        iterable = [1, 2, 'truthy', True, False, 0, '']
        truthy_iterable = remove_falsey_values(iterable)
        assert list(truthy_iterable) == [1, 2, 'truthy', True]

    def test_is_subset_dict(self):
        """
        Tests that is_subset_dict properly determines whether or not a dict is a subset of another dict
        """
        d1 = {'a': 1, 'b': 2, 'c': {'d': 3}}
        d2 = {'a': 1, 'b': 2, 'c': {'d': 3}, 'e': 4}
        assert is_subset_dict(d1, d2)
        assert is_subset_dict(d1, d1)
        assert not is_subset_dict(d2, d1)
        new_dict = dict(d1)
        new_dict['f'] = 5
        assert not is_subset_dict(new_dict, d2)
        new_dict = dict(d1)
        new_dict['a'] = 2
        assert not is_subset_dict(new_dict, d2)
        # A shallow copy would share d1['c'], so mutating the nested dict would
        # silently change d1 as well; deepcopy keeps the fixture intact.
        new_dict = copy.deepcopy(d1)
        new_dict['c']['d'] = 123
        assert not is_subset_dict(new_dict, d2)
        # d1 must not have been altered by the nested mutation above
        assert is_subset_dict(d1, d2)

    def test_is_near_now(self):
        """
        Test is_near_now for now
        """
        now = datetime.datetime.now(tz=pytz.UTC)
        assert is_near_now(now) is True
        later = now + datetime.timedelta(0, 6)
        assert is_near_now(later) is False
        earlier = now - datetime.timedelta(0, 6)
        assert is_near_now(earlier) is False

    def test_chunks(self):
        """
        test for chunks
        """
        input_list = list(range(113))
        output_list = []
        for nums in chunks(input_list):
            output_list += nums
        assert output_list == input_list

        output_list = []
        for nums in chunks(input_list, chunk_size=1):
            output_list += nums
        assert output_list == input_list

        output_list = []
        for nums in chunks(input_list, chunk_size=124):
            output_list += nums
        assert output_list == input_list

    def test_chunks_iterable(self):
        """
        test that chunks works on non-list iterables too
        """
        count = 113
        input_range = range(count)
        chunk_output = []
        for chunk in chunks(input_range, chunk_size=10):
            chunk_output.append(chunk)
        assert len(chunk_output) == ceil(count / 10)

        range_list = []
        for chunk in chunk_output:
            range_list += chunk
        assert range_list == list(range(count))

    def test_safely_remove_file(self):
        """test for safely_remove_file"""
        # Work inside a private temporary directory rather than hard-coded /tmp
        # paths, so concurrent or repeated runs can't collide and any leftover
        # file is cleaned up even if an assertion fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # shouldn't error if the file already got removed (or never existed)
            safely_remove_file(os.path.join(tmp_dir, 'unlikely_to_exist'))

            file_path = os.path.join(tmp_dir, 'test_file.txt')
            pathlib.Path(file_path).touch()
            assert os.path.exists(file_path) is True
            # removes the file
            safely_remove_file(file_path)
            assert os.path.exists(file_path) is False

    def test_dict_with_keys(self):
        """Tests that dict_with_keys correctly extracts the specified keys"""
        source_dict = {'a': 1, 'b': 2}
        assert dict_with_keys(source_dict, ['a']) == {
            'a': 1,
        }
        assert dict_with_keys(source_dict, ['a', 'b']) == source_dict


def test_as_datetime():
    """as_datetime should convert a date to datetime at midnight, UTC"""
    a_while_ago = datetime.date(2016, 3, 4)
    assert as_datetime(a_while_ago) == datetime.datetime(2016, 3, 4, tzinfo=pytz.UTC)


def test_now_in_utc():
    """now_in_utc() should return the current time set to the UTC time zone"""
    now = now_in_utc()
    assert is_near_now(now)
    assert now.tzinfo == pytz.UTC


def test_pop_keys_from_dict():
    """pop_keys_from_dict should remove keys from a source dict and return a dict of removed key-values"""
    orig_dict = dict(a=1, b=2, c=3, d=4)
    new_dict = pop_keys_from_dict(orig_dict, ['a', 'd'])
    assert new_dict == dict(a=1, d=4)
    assert orig_dict == dict(b=2, c=3)
    new_dict = pop_keys_from_dict(orig_dict, ['non-existent key'])
    assert new_dict == {}
    assert orig_dict == dict(b=2, c=3)


def test_pop_matching_keys_from_dict():
    """
    pop_matching_keys_from_dict should remove matching keys from a source dict and return a dict of
    removed key-values
    """
    orig_dict = dict(a=1, b=2, c=3, d=4)
    new_dict = pop_matching_keys_from_dict(orig_dict, lambda k: k in ['a', 'd'])
    assert new_dict == dict(a=1, d=4)
    assert orig_dict == dict(b=2, c=3)
    new_dict = pop_matching_keys_from_dict(orig_dict, lambda k: k == 'non-existent key')
    assert new_dict == {}
    assert orig_dict == dict(b=2, c=3)


def test_generate_md5():
    """Test that generate_md5 generates an MD5 hash"""
    bytes_to_hash = 'abc'.encode('utf-8')
    md5_hash = generate_md5(bytes_to_hash)
    assert isinstance(md5_hash, str)
    # 32 characters matches the length of an MD5 hex digest
    assert len(md5_hash) == 32
    repeat_md5_hash = generate_md5(bytes_to_hash)
    assert md5_hash == repeat_md5_hash