import datetime
import base64
import imghdr
import os
import copy

import django
from django.core.exceptions import ValidationError
from django.test import TestCase, override_settings
from mock import patch
import pytz
from rest_framework import serializers

from drf_extra_fields import compat
from drf_extra_fields.geo_fields import PointField
from drf_extra_fields.fields import (
    Base64ImageField,
    Base64FileField,
    DateRangeField,
    DateTimeRangeField,
    FloatRangeField,
    HybridImageField,
    IntegerRangeField,
    LowercaseEmailField,
)


import pytest

# Base64 JPEG sample, written as a data URI, used by the imghdr-fallback test
# in Base64ImageSerializerTests. The payload is split across several lines.
UNDETECTABLE_BY_IMGHDR_SAMPLE = """data:image/jpeg;base64,
/9j/2wBDAAMCAgICAgMCAgIDAwMDBAYEBAQEBAgGBgUGCQgKCgkICQkKDA8MCgsOCwkJDRENDg8QEBEQCgwSExIQEw8QEBD
/2wBDAQMDAwQDBAgEBAgQCwkLEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBD/wAARCAAQABADASIAAhEBAxEB
/8QAFgABAQEAAAAAAAAAAAAAAAAABwQF/8QAJBAAAQQBBAICAwAAAAAAAAAAAQIDBAYFBwgSExEiABQJMTL/xAAVAQEBAAAAAAAAAAAAAAAAAAAABv
/EACMRAAECBQMFAAAAAAAAAAAAAAECEQMEBQYhABIxFRZhgeH/2gAMAwEAAhEDEQA/ABSm0mobc8HmExLUlRzzEWPkJWW
+ulrsaUVAseUgslSlH9LKuPryIKuWPZdskzXmm3fX5m2nF4GlVxx/HOpx4ks51+MiU/Iaad7UcUo4tILoS4kqcWkezS0hO
/HvuRp0rO6hWnWO1UisZVuFi4GFeyEpmGepa5S5SWVPuciFKRFLgSrwetnyPIB+Vb4N9mKhQMzo5po9XLdDs9d6ZVix2VEhiL9kuNPxw2gEKcDQ
/rs8AuA8VAe0vdl7VOYn+27flGAUgmITjbhSmCg3BYlyeWDkMolvw4KOp1KM6iCNvngZHwetf//Z """


class UploadedBase64Image(object):
    """
    Plain in-memory holder for an uploaded image and its creation time.

    ``file`` is stored exactly as given. A falsy ``created`` (including the
    default ``None``) is replaced with the current local time.
    """

    def __init__(self, file=None, created=None):
        if not created:
            created = datetime.datetime.now()
        self.created = created
        self.file = file


class UploadedBase64File(UploadedBase64Image):
    """
    In-memory holder for an uploaded file.

    Adds nothing to ``UploadedBase64Image``; it exists so the file tests
    work with their own type.
    """


class DownloadableBase64Image(object):
    """
    Object exposing an ``image`` attribute whose value carries a ``path``.

    Used as the instance handed to the download serializer in the tests.
    """

    class ImageFieldFile(object):
        """Minimal file wrapper holding only the file-system ``path``."""

        def __init__(self, path):
            self.path = path

    def __init__(self, image_path):
        wrapper_cls = self.ImageFieldFile
        self.image = wrapper_cls(path=image_path)


class DownloadableBase64File(object):
    """
    Object exposing a ``file`` attribute whose value carries a ``path``.

    Used as the instance handed to the file download serializer in the tests.
    """

    class FieldFile(object):
        """Minimal file wrapper holding only the file-system ``path``."""

        def __init__(self, path):
            self.path = path

    def __init__(self, file_path):
        wrapper_cls = self.FieldFile
        self.file = wrapper_cls(path=file_path)


class UploadedBase64ImageSerializer(serializers.Serializer):
    """
    Test serializer pairing an optional Base64 image with a creation time.
    """
    # The image is optional; ``created`` is declared without ``required=False``.
    file = Base64ImageField(required=False)
    created = serializers.DateTimeField()

    def update(self, instance, validated_data):
        # Only ``file`` is copied onto the instance; ``created`` is not touched.
        instance.file = validated_data['file']
        return instance

    def create(self, validated_data):
        # Build a new in-memory object straight from the validated fields.
        return UploadedBase64Image(**validated_data)


class DownloadableBase64ImageSerializer(serializers.Serializer):
    """
    Test serializer whose ``image`` field is declared with ``represent_in_base64=True``.
    """
    image = Base64ImageField(represent_in_base64=True)


class Base64ImageSerializerTests(TestCase):
    """
    Tests for ``Base64ImageField`` (and ``HybridImageField``) driven through
    small test serializers.
    """

    def test_create(self):
        """
        Test for creating Base64 image in the server side
        """
        now = datetime.datetime.now()
        file = 'R0lGODlhAQABAIAAAP///////yH5BAEKAAEALAAAAAABAAEAAAICTAEAOw=='
        serializer = UploadedBase64ImageSerializer(data={'created': now, 'file': file})
        uploaded_image = UploadedBase64Image(file=file, created=now)
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], uploaded_image.created)
        self.assertIsNot(serializer.validated_data, uploaded_image)

    def test_create_with_base64_prefix(self):
        """
        Test for creating Base64 image from a data URI, i.e. a payload
        preceded by a ``data:<mime>;base64,`` header
        """
        now = datetime.datetime.now()
        # Same GIF payload as in test_create, but with the data-URI header in
        # front so that the prefix handling is really exercised.
        file = 'data:image/gif;base64,R0lGODlhAQABAIAAAP///////yH5BAEKAAEALAAAAAABAAEAAAICTAEAOw=='
        serializer = UploadedBase64ImageSerializer(data={'created': now, 'file': file})
        uploaded_image = UploadedBase64Image(file=file, created=now)
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], uploaded_image.created)
        self.assertIsNot(serializer.validated_data, uploaded_image)
        # A prefixed payload must still produce a file, not an empty value.
        self.assertIsNotNone(serializer.validated_data['file'])

    def test_create_with_invalid_base64(self):
        """
        Test for creating Base64 image with an invalid Base64 string in the server side
        """
        now = datetime.datetime.now()
        file = 'this_is_not_a_base64'
        serializer = UploadedBase64ImageSerializer(data={'created': now, 'file': file})

        self.assertFalse(serializer.is_valid())
        self.assertEqual(serializer.errors, {'file': [Base64ImageField.INVALID_FILE_MESSAGE]})

    def test_validation_error_with_non_file(self):
        """
        Passing non-base64 should raise a validation error.
        """
        now = datetime.datetime.now()
        serializer = UploadedBase64ImageSerializer(data={'created': now,
                                                         'file': 'abc'})
        self.assertFalse(serializer.is_valid())
        self.assertEqual(serializer.errors, {'file': [Base64ImageField.INVALID_FILE_MESSAGE]})

    def test_remove_with_empty_string(self):
        """
        Passing empty string as data should cause image to be removed
        """
        now = datetime.datetime.now()
        file = 'R0lGODlhAQABAIAAAP///////yH5BAEKAAEALAAAAAABAAEAAAICTAEAOw=='
        uploaded_image = UploadedBase64Image(file=file, created=now)
        serializer = UploadedBase64ImageSerializer(instance=uploaded_image, data={'created': now, 'file': ''})
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], uploaded_image.created)
        self.assertIsNone(serializer.validated_data['file'])

    def test_fallback_to_pil_if_not_detected_by_imghdr(self):
        """
        Passing a sample image which goes undetected by imghdr should
        still be detected by PIL.
        """
        now = datetime.datetime.now()
        file = UNDETECTABLE_BY_IMGHDR_SAMPLE

        # check image is undetectable by imghdr
        # NOTE(review): the sample is decoded here together with its
        # "data:image/jpeg;base64," header, so imghdr is not handed the bare
        # image bytes -- confirm this precondition is as strict as intended.
        self.assertIsNone(imghdr.what('test.jpeg', base64.b64decode(file)))

        uploaded_image = UploadedBase64Image(file=file, created=now)
        serializer = UploadedBase64ImageSerializer(instance=uploaded_image, data={'created': now, 'file': file})
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], uploaded_image.created)
        self.assertIsNotNone(serializer.validated_data['file'])

    def test_download(self):
        """
        An image stored on disk should be represented as its Base64 content
        """
        encoded_source = 'R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs='
        # Resolve the path once so cleanup removes exactly the file written.
        image_path = os.path.abspath('im.jpg')

        with open(image_path, 'wb') as im_file:
            im_file.write(base64.b64decode(encoded_source))

        # Everything after the file exists runs under the cleanup guard, so a
        # failure while building the serializer cannot leave the file behind.
        try:
            image = DownloadableBase64Image(image_path)
            serializer = DownloadableBase64ImageSerializer(image)
            self.assertEqual(serializer.data['image'], encoded_source)
        finally:
            os.remove(image_path)

    def test_hybrid_image_field(self):
        """
        HybridImageField should call the Base64 mixin first, and also call
        ImageField when the mixin raises a ValidationError
        """
        field = HybridImageField()
        with patch('drf_extra_fields.fields.Base64FieldMixin') as mixin_patch:
            field.to_internal_value({})
            self.assertTrue(mixin_patch.to_internal_value.called)

        with patch('drf_extra_fields.fields.Base64FieldMixin') as mixin_patch:
            mixin_patch.to_internal_value.side_effect = ValidationError('foobar')
            with patch('drf_extra_fields.fields.ImageField') as image_patch:
                field.to_internal_value({})
                self.assertTrue(mixin_patch.to_internal_value.called)
                self.assertTrue(image_patch.to_internal_value.called)


class PDFBase64FileField(Base64FileField):
    """
    ``Base64FileField`` subclass used by the file tests; it reports every
    upload as a PDF.
    """
    # Restrict the allowed types to 'pdf'.
    ALLOWED_TYPES = ('pdf',)

    def get_file_extension(self, filename, decoded_file):
        # Always answer 'pdf'; neither the name nor the content is inspected.
        return 'pdf'


class UploadedBase64FileSerializer(serializers.Serializer):
    """
    Test serializer pairing an optional Base64 PDF file with a creation time.
    """
    # The file is optional; ``created`` is declared without ``required=False``.
    file = PDFBase64FileField(required=False)
    created = serializers.DateTimeField()

    def update(self, instance, validated_data):
        # Only ``file`` is copied onto the instance; ``created`` is not touched.
        instance.file = validated_data['file']
        return instance

    def create(self, validated_data):
        # Build a new in-memory object straight from the validated fields.
        return UploadedBase64File(**validated_data)


class DownloadableBase64FileSerializer(serializers.Serializer):
    """
    Test serializer whose ``file`` field is declared with ``represent_in_base64=True``.
    """
    file = PDFBase64FileField(represent_in_base64=True)


class Base64FileSerializerTests(TestCase):
    """
    Tests for ``Base64FileField`` driven through small test serializers
    built on ``PDFBase64FileField``.
    """

    def test_create(self):
        """
        Test for creating Base64 file in the server side
        """
        now = datetime.datetime.now()
        file = 'R0lGODlhAQABAIAAAP///////yH5BAEKAAEALAAAAAABAAEAAAICTAEAOw=='
        serializer = UploadedBase64FileSerializer(data={'created': now, 'file': file})
        uploaded_file = UploadedBase64File(file=file, created=now)
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], uploaded_file.created)
        self.assertIsNot(serializer.validated_data, uploaded_file)

    def test_create_with_base64_prefix(self):
        """
        Test for creating Base64 file from a data URI, i.e. a payload
        preceded by a ``data:<mime>;base64,`` header
        """
        now = datetime.datetime.now()
        # Same payload as in test_create, but with the data-URI header in
        # front so that the prefix handling is really exercised.
        file = 'data:image/gif;base64,R0lGODlhAQABAIAAAP///////yH5BAEKAAEALAAAAAABAAEAAAICTAEAOw=='
        serializer = UploadedBase64FileSerializer(data={'created': now, 'file': file})
        uploaded_file = UploadedBase64File(file=file, created=now)
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], uploaded_file.created)
        self.assertIsNot(serializer.validated_data, uploaded_file)
        # A prefixed payload must still produce a file, not an empty value.
        self.assertIsNotNone(serializer.validated_data['file'])

    def test_validation_error_with_non_file(self):
        """
        Passing non-base64 should raise a validation error.
        """
        now = datetime.datetime.now()
        serializer = UploadedBase64FileSerializer(data={'created': now, 'file': 'abc'})
        self.assertFalse(serializer.is_valid())
        self.assertEqual(serializer.errors, {'file': [Base64FileField.INVALID_FILE_MESSAGE]})

    def test_remove_with_empty_string(self):
        """
        Passing empty string as data should cause file to be removed
        """
        now = datetime.datetime.now()
        file = 'R0lGODlhAQABAIAAAP///////yH5BAEKAAEALAAAAAABAAEAAAICTAEAOw=='
        uploaded_file = UploadedBase64File(file=file, created=now)
        serializer = UploadedBase64FileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], uploaded_file.created)
        self.assertIsNone(serializer.validated_data['file'])

    def test_download(self):
        """
        A file stored on disk should be represented as its Base64 content
        """
        encoded_source = 'R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs='
        # Resolve the path once so cleanup removes exactly the file written.
        file_path = os.path.abspath('im.jpg')

        with open(file_path, 'wb') as im_file:
            im_file.write(base64.b64decode(encoded_source))

        # Everything after the file exists runs under the cleanup guard, so a
        # failure while building the serializer cannot leave the file behind.
        try:
            file = DownloadableBase64File(file_path)
            serializer = DownloadableBase64FileSerializer(file)
            self.assertEqual(serializer.data['file'], encoded_source)
        finally:
            os.remove(file_path)


class SavePoint(object):
    """
    Plain in-memory holder for a point and its creation time.

    ``point`` is stored exactly as given. A falsy ``created`` (including the
    default ``None``) is replaced with the current local time.
    """

    def __init__(self, point=None, created=None):
        if not created:
            created = datetime.datetime.now()
        self.created = created
        self.point = point


class PointSerializer(serializers.Serializer):
    """
    Test serializer pairing an optional ``PointField`` with a creation time.
    """
    # The point is optional; ``created`` is declared without ``required=False``.
    point = PointField(required=False)
    created = serializers.DateTimeField()

    def update(self, instance, validated_data):
        # Only ``point`` is copied onto the instance; ``created`` is not touched.
        instance.point = validated_data['point']
        return instance

    def create(self, validated_data):
        # Build a new in-memory object straight from the validated fields.
        return SavePoint(**validated_data)


class StringPointSerializer(PointSerializer):
    """
    ``PointSerializer`` variant whose ``point`` field is declared with ``str_points=True``.
    """
    point = PointField(required=False, str_points=True)


class PointSerializerTest(TestCase):
    """
    Tests for ``PointField`` driven through ``PointSerializer`` and
    ``StringPointSerializer``.
    """

    def test_create(self):
        """
        Test for creating Point field in the server side
        """
        now = datetime.datetime.now()
        point = {
            "latitude": 49.8782482189424,
            "longitude": 24.452545489
        }
        serializer = PointSerializer(data={'created': now, 'point': point})
        saved_point = SavePoint(point=point, created=now)
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], saved_point.created)
        self.assertIsNot(serializer.validated_data, saved_point)

    def test_validation_error_with_non_file(self):
        """
        Passing a value that is not a dict containing latitude and longitude
        should raise a validation error.
        """
        now = datetime.datetime.now()
        serializer = PointSerializer(data={'created': now, 'point': '123'})
        self.assertFalse(serializer.is_valid())

    def test_remove_with_empty_string(self):
        """
        Passing empty string as data should cause point to be removed
        """
        now = datetime.datetime.now()
        point = {
            "latitude": 49.8782482189424,
            "longitude": 24.452545489
        }
        saved_point = SavePoint(point=point, created=now)
        serializer = PointSerializer(data={'created': now, 'point': ''})
        self.assertTrue(serializer.is_valid())
        self.assertEqual(serializer.validated_data['created'], saved_point.created)
        self.assertIsNone(serializer.validated_data['point'])

    def test_empty_latitude(self):
        """
        An empty latitude should be rejected
        """
        now = datetime.datetime.now()
        point = {
            "latitude": "",
            "longitude": 24.452545489
        }
        serializer = PointSerializer(data={'created': now, 'point': point})
        self.assertFalse(serializer.is_valid())

    def test_empty_longitude(self):
        """
        An empty longitude should be rejected
        """
        now = datetime.datetime.now()
        point = {
            "latitude": 49.8782482189424,
            "longitude": ""
        }
        serializer = PointSerializer(data={'created': now, 'point': point})
        self.assertFalse(serializer.is_valid())

    def test_invalid_latitude(self):
        """
        A non-numeric latitude should be rejected
        """
        now = datetime.datetime.now()
        point = {
            "latitude": "fdff",
            "longitude": 24.452545489
        }
        serializer = PointSerializer(data={'created': now, 'point': point})
        self.assertFalse(serializer.is_valid())

    def test_invalid_longitude(self):
        """
        A non-numeric longitude should be rejected
        """
        now = datetime.datetime.now()
        point = {
            "latitude": 49.8782482189424,
            "longitude": "fdff"
        }
        serializer = PointSerializer(data={'created': now, 'point': point})
        self.assertFalse(serializer.is_valid())

    def test_serialization(self):
        """
        Regular JSON serialization should output float values
        """
        # Imported locally, as in the rest of this class, so the module can
        # still be imported when the GIS bindings are unavailable.
        from django.contrib.gis.geos import Point
        now = datetime.datetime.now()
        point = Point(24.452545489, 49.8782482189424)
        saved_point = SavePoint(point=point, created=now)
        serializer = PointSerializer(saved_point)
        self.assertEqual(serializer.data['point'], {'latitude': 49.8782482189424, 'longitude': 24.452545489})

    def test_str_points_serialization(self):
        """
        PointField with str_points=True should output string values
        """
        from django.contrib.gis.geos import Point
        now = datetime.datetime.now()
        # test input is shortened due to string conversion rounding
        # gps has at max 8 decimals, so it doesn't make a difference
        point = Point(24.452545489, 49.8782482189)
        saved_point = SavePoint(point=point, created=now)
        serializer = StringPointSerializer(saved_point)
        self.assertEqual(serializer.data['point'], {'latitude': '49.8782482189', 'longitude': '24.452545489'})


# Backported from django_rest_framework/tests/test_fields.py
def get_items(mapping_or_list_of_two_tuples):
    """
    Return an iterable of ``(value, expected)`` pairs.

    A dict is turned into its ``items()`` view; anything else is assumed to
    already be a sequence of two-tuples and is returned unchanged.
    """
    source = mapping_or_list_of_two_tuples
    if not isinstance(source, dict):
        # Already in [(value, expected), ...] form.
        return source
    # {value: expected} -> (value, expected) pairs.
    return source.items()


class IntegerRangeSerializer(serializers.Serializer):
    """Test serializer with a single ``IntegerRangeField`` named ``range``."""

    range = IntegerRangeField()


class FloatRangeSerializer(serializers.Serializer):
    """Test serializer with a single ``FloatRangeField`` named ``range``."""

    range = FloatRangeField()


class DateTimeRangeSerializer(serializers.Serializer):
    """Test serializer with a single ``DateTimeRangeField`` named ``range``."""

    range = DateTimeRangeField()


class DateRangeSerializer(serializers.Serializer):
    """Test serializer with a single ``DateRangeField`` named ``range``."""

    range = DateRangeField()


class FieldValues:
    """
    Shared checks for field test classes.

    The methods read ``field``, ``serializer_class``, ``valid_inputs``,
    ``invalid_inputs`` and ``outputs`` from the class that mixes this in.
    """
    def test_valid_inputs(self):
        """
        Valid values must validate to the expected data, and the serializer
        must keep its initial data equal to what was passed in.
        """
        for raw_value, expected in get_items(self.valid_inputs):
            # Snapshot taken before the serializer sees the value.
            snapshot = copy.deepcopy(raw_value)

            serializer = self.serializer_class(data=raw_value)
            serializer.is_valid()

            assert serializer.initial_data == snapshot
            assert self.field.run_validation(snapshot) == expected

    def test_invalid_inputs(self):
        """
        Invalid values must raise a validation error carrying the expected detail.
        """
        for raw_value, expected_detail in get_items(self.invalid_inputs):
            with pytest.raises(serializers.ValidationError) as exc_info:
                self.field.run_validation(raw_value)
            assert exc_info.value.detail == expected_detail

    def test_outputs(self):
        """
        Values must be represented as the expected primitive output.
        """
        for internal_value, expected in get_items(self.outputs):
            assert self.field.to_representation(internal_value) == expected

# end of backport


@pytest.mark.skipif(django.VERSION < (1, 8) or compat.postgres_fields is None,
                    reason='RangeField is only available for django1.8+'
                    ' and with psycopg2.')
class TestIntegerRangeField(FieldValues):
    """
    Valid inputs, invalid inputs and outputs for `IntegerRangeField`.
    """
    serializer_class = IntegerRangeSerializer
    field = IntegerRangeField()

    # The tables need the range type, so they are only built when it exists.
    if compat.NumericRange is not None:
        valid_inputs = [
            ({'lower': '1', 'upper': 2, 'bounds': '[)'},
             compat.NumericRange(lower=1, upper=2, bounds='[)')),
            ({'lower': 1, 'upper': 2},
             compat.NumericRange(lower=1, upper=2)),
            ({'lower': 1},
             compat.NumericRange(lower=1)),
            ({'upper': 1},
             compat.NumericRange(upper=1)),
            ({'empty': True},
             compat.NumericRange(empty=True)),
            ({}, compat.NumericRange()),
        ]
        invalid_inputs = [
            ({'lower': 'a'}, ['A valid integer is required.']),
            ('not a dict', ['Expected a dictionary of items but got type "str".']),
            ({'foo': 'bar'}, ['Extra content not allowed "foo".']),
        ]
        outputs = [
            (compat.NumericRange(lower='1', upper='2'),
             {'lower': 1, 'upper': 2, 'bounds': '[)'}),
            (compat.NumericRange(empty=True), {'empty': True}),
            (compat.NumericRange(), {'bounds': '[)', 'lower': None, 'upper': None}),
        ]

    def test_no_source_on_child(self):
        """
        Declaring the child field with `source=` must trip an assertion.
        """
        expected_message = (
            "The `source` argument is not meaningful when applied to a `child=` field. "
            "Remove `source=` from the field declaration."
        )
        with pytest.raises(AssertionError) as exc_info:
            IntegerRangeField(child=serializers.IntegerField(source='other'))

        assert str(exc_info.value) == expected_message


@pytest.mark.skipif(django.VERSION < (1, 8) or compat.postgres_fields is None,
                    reason='RangeField is only available for django1.8+'
                    ' and with psycopg2.')
class TestFloatRangeField(FieldValues):
    """
    Valid inputs, invalid inputs and outputs for `FloatRangeField`.
    """
    serializer_class = FloatRangeSerializer
    field = FloatRangeField()

    # The tables need the range type, so they are only built when it exists.
    if compat.NumericRange is not None:
        valid_inputs = [
            ({'lower': '1', 'upper': 2., 'bounds': '[)'},
             compat.NumericRange(lower=1., upper=2., bounds='[)')),
            ({'lower': 1., 'upper': 2.},
             compat.NumericRange(lower=1, upper=2)),
            ({'lower': 1},
             compat.NumericRange(lower=1)),
            ({'upper': 1},
             compat.NumericRange(upper=1)),
            ({'empty': True},
             compat.NumericRange(empty=True)),
            ({}, compat.NumericRange()),
        ]
        invalid_inputs = [
            ({'lower': 'a'}, ['A valid number is required.']),
            ('not a dict', ['Expected a dictionary of items but got type "str".']),
        ]
        outputs = [
            (compat.NumericRange(lower='1.1', upper='2'),
             {'lower': 1.1, 'upper': 2, 'bounds': '[)'}),
            (compat.NumericRange(empty=True), {'empty': True}),
            (compat.NumericRange(), {'bounds': '[)', 'lower': None, 'upper': None}),
        ]

    def test_no_source_on_child(self):
        """
        Declaring the child field with `source=` must trip an assertion.
        """
        expected_message = (
            "The `source` argument is not meaningful when applied to a `child=` field. "
            "Remove `source=` from the field declaration."
        )
        with pytest.raises(AssertionError) as exc_info:
            FloatRangeField(child=serializers.IntegerField(source='other'))

        assert str(exc_info.value) == expected_message


@pytest.mark.skipif(django.VERSION < (1, 8) or compat.postgres_fields is None,
                    reason='RangeField is only available for django1.8+'
                    ' and with psycopg2.')
@override_settings(USE_TZ=True)
class TestDateTimeRangeField(TestCase, FieldValues):
    """
    Valid inputs, invalid inputs and outputs for `DateTimeRangeField`.
    """
    serializer_class = DateTimeRangeSerializer
    field = DateTimeRangeField()

    # The tables need the range type, so they are only built when it exists.
    if compat.DateTimeTZRange is not None:
        valid_inputs = [
            ({'lower': '2001-01-01T13:00:00Z',
              'upper': '2001-02-02T13:00:00Z',
              'bounds': '[)'},
             compat.DateTimeTZRange(
                 lower=datetime.datetime(2001, 1, 1, 13, 0, tzinfo=pytz.utc),
                 upper=datetime.datetime(2001, 2, 2, 13, 0, tzinfo=pytz.utc),
                 bounds='[)')),
            ({'upper': '2001-02-02T13:00:00Z',
              'bounds': '[)'},
             compat.DateTimeTZRange(
                 upper=datetime.datetime(2001, 2, 2, 13, 0, tzinfo=pytz.utc),
                 bounds='[)')),
            ({'lower': '2001-01-01T13:00:00Z',
              'bounds': '[)'},
             compat.DateTimeTZRange(
                 lower=datetime.datetime(2001, 1, 1, 13, 0, tzinfo=pytz.utc),
                 bounds='[)')),
            ({'empty': True},
             compat.DateTimeTZRange(empty=True)),
            ({}, compat.DateTimeTZRange()),
        ]
        invalid_inputs = [
            ({'lower': 'a'}, ['Datetime has wrong format. Use one of these'
                              ' formats instead: '
                              'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].']),
            ('not a dict', ['Expected a dictionary of items but got type "str".']),
        ]
        outputs = [
            (compat.DateTimeTZRange(
                lower=datetime.datetime(2001, 1, 1, 13, 0, tzinfo=pytz.utc),
                upper=datetime.datetime(2001, 2, 2, 13, 0, tzinfo=pytz.utc)),
             {'lower': '2001-01-01T13:00:00Z',
              'upper': '2001-02-02T13:00:00Z',
              'bounds': '[)'}),
            (compat.DateTimeTZRange(empty=True),
             {'empty': True}),
            (compat.DateTimeTZRange(),
             {'bounds': '[)', 'lower': None, 'upper': None}),
        ]

    def test_no_source_on_child(self):
        """
        Declaring the child field with `source=` must trip an assertion.
        """
        expected_message = (
            "The `source` argument is not meaningful when applied to a `child=` field. "
            "Remove `source=` from the field declaration."
        )
        with pytest.raises(AssertionError) as exc_info:
            DateTimeRangeField(child=serializers.IntegerField(source='other'))

        assert str(exc_info.value) == expected_message


@pytest.mark.skipif(django.VERSION < (1, 8) or compat.postgres_fields is None,
                    reason='RangeField is only available for django1.8+'
                    ' and with psycopg2.')
class TestDateRangeField(FieldValues):
    """
    Valid inputs, invalid inputs and outputs for `DateRangeField`.
    """
    serializer_class = DateRangeSerializer
    field = DateRangeField()

    # The tables need the range type, so they are only built when it exists.
    if compat.DateRange is not None:
        valid_inputs = [
            ({'lower': '2001-01-01',
              'upper': '2001-02-02',
              'bounds': '[)'},
             compat.DateRange(
                 lower=datetime.date(2001, 1, 1),
                 upper=datetime.date(2001, 2, 2),
                 bounds='[)')),
            ({'upper': '2001-02-02',
              'bounds': '[)'},
             compat.DateRange(
                 upper=datetime.date(2001, 2, 2),
                 bounds='[)')),
            ({'lower': '2001-01-01',
              'bounds': '[)'},
             compat.DateRange(
                 lower=datetime.date(2001, 1, 1),
                 bounds='[)')),
            ({'empty': True},
             compat.DateRange(empty=True)),
            ({}, compat.DateRange()),
        ]
        invalid_inputs = [
            ({'lower': 'a'}, ['Date has wrong format. Use one of these'
                              ' formats instead: '
                              'YYYY-MM-DD.']),
            ('not a dict', ['Expected a dictionary of items but got type "str".']),
        ]
        outputs = [
            (compat.DateRange(
                lower=datetime.date(2001, 1, 1),
                upper=datetime.date(2001, 2, 2)),
             {'lower': '2001-01-01',
              'upper': '2001-02-02',
              'bounds': '[)'}),
            (compat.DateRange(empty=True),
             {'empty': True}),
            (compat.DateRange(), {'bounds': '[)', 'lower': None, 'upper': None}),
        ]

    def test_no_source_on_child(self):
        """
        Declaring the child field with `source=` must trip an assertion.
        """
        expected_message = (
            "The `source` argument is not meaningful when applied to a `child=` field. "
            "Remove `source=` from the field declaration."
        )
        with pytest.raises(AssertionError) as exc_info:
            DateRangeField(child=serializers.IntegerField(source='other'))

        assert str(exc_info.value) == expected_message


class EmailSerializer(serializers.Serializer):
    """Test serializer with a single ``LowercaseEmailField`` named ``email``."""
    email = LowercaseEmailField()


class LowercaseEmailFieldTest(TestCase):
    """
    Tests for ``LowercaseEmailField`` driven through ``EmailSerializer``.
    """

    def test_serialization(self):
        """
        An upper-case address should validate and come back lower-cased
        """
        email = 'ALL_CAPS@example.com'
        serializer = EmailSerializer(data={'email': email})
        # Validate once; the errors are passed as the failure message so a
        # rejected address is explained in the test output.
        self.assertTrue(serializer.is_valid(), serializer.errors)
        self.assertEqual(serializer.validated_data['email'], email.lower())