import re
from bson import ObjectId
from bson.dbref import DBRef
from bson.errors import InvalidId
from django.utils.encoding import smart_str
from django.utils.datastructures import MultiValueDict
from rest_framework import fields
from rest_framework.exceptions import ValidationError


class DateTime000Field(fields.DateTimeField):
    """DateTimeField that truncates parsed values to millisecond precision.

    The sub-millisecond part of ``microsecond`` is dropped (always rounded
    down), so e.g. ``.123456`` becomes ``.123000``.
    NOTE(review): presumably this matches the millisecond resolution of
    BSON datetimes -- confirm against the callers.
    """

    def to_internal_value(self, value):
        """Parse ``value`` with the parent field, then zero the last three
        microsecond digits of the resulting datetime."""
        parsed = super().to_internal_value(value)
        sub_millisecond = parsed.microsecond % 1000
        return parsed.replace(microsecond=parsed.microsecond - sub_millisecond)


class ObjectIdField(fields.Field):
    """Serializer field for BSON ``ObjectId`` values.

    Values are rendered as strings on output and converted to ``ObjectId``
    on input.
    """
    type_label = 'ObjectIdField'

    def to_representation(self, value):
        """Return the string form of ``value``."""
        return smart_str(value)

    def to_internal_value(self, data):
        """Convert ``data`` to an ``ObjectId``.

        Raises:
            ValidationError: if ``data`` cannot be turned into an ObjectId.
        """
        try:
            return ObjectId(data)
        except (InvalidId, TypeError) as e:
            # ObjectId raises InvalidId for malformed strings but TypeError
            # for input of the wrong type (e.g. a JSON number); both are
            # client errors and must surface as validation failures rather
            # than escaping as an unhandled exception.
            raise ValidationError(str(e)) from e


class ListField(fields.ListField):
    """List field that also understands repeated query parameters.

    ``?foo=1&foo=2&foo=3`` is read as the list of the three submitted
    values, each of which is then validated by ``self.child``.
    """
    def get_value(self, data):
        """Extract the raw value stored under ``self.field_name``.

        For a ``MultiValueDict`` all values submitted under the name are
        collected; for a plain dict the stored value is returned unchanged.
        A missing key, or a list holding only one empty string (as produced
        by ``?foo=``), yields ``fields.empty``.

        Raises:
            ValidationError: if ``data`` is not a dict.
        """
        if isinstance(data, MultiValueDict):
            ret = data.getlist(self.field_name, fields.empty)
        elif isinstance(data, dict):
            ret = data.get(self.field_name, fields.empty)
        else:
            raise ValidationError("not a dict: " + str(type(data)))
        if ret == ['']:
            ret = fields.empty
        return ret

    def to_internal_value(self, data):
        """Validate every item of ``data`` with the child field.

        Raises:
            ValidationError: if ``data`` is not a list-like iterable, or if
                the child field rejects an item.
        """
        # Strings, bytes and dicts are iterable too, but iterating them
        # would silently validate single characters / keys instead of list
        # items, so they are rejected up front.
        if isinstance(data, (str, bytes, dict)) or not hasattr(data, '__iter__'):
            raise ValidationError("not a list: " + str(type(data)))
        return [self.child.run_validation(item) for item in data]

class DictField(fields.DictField):
    """Dict field that also understands dotted query parameters.

    ``?foo.bar=1&foo.baz=2`` is read as ``{'bar': '1', 'baz': '2'}`` for a
    field named ``foo``; every value is then validated by ``self.child``.

    ``valid_keys`` (optional) restricts which keys may appear and
    ``required_keys`` (optional) lists keys that must appear. Both may be
    set as class attributes on a subclass or passed to the constructor.
    """
    valid_keys = None
    required_keys = None

    def __init__(self, valid_keys=None, required_keys=None, **kwargs):
        """Store the key constraints as sets.

        A truthy constructor argument overrides the class-level default;
        remaining keyword arguments are forwarded to the parent field.
        """
        if valid_keys:
            self.valid_keys = valid_keys
        if self.valid_keys is not None:
            self.valid_keys = set(self.valid_keys)

        if required_keys:
            self.required_keys = required_keys
        if self.required_keys is not None:
            self.required_keys = set(self.required_keys)

        super().__init__(**kwargs)

    def get_value(self, data):
        """Extract the raw value for this field from ``data``.

        For a ``MultiValueDict`` the entries named ``<field_name>.<key>``
        are gathered into a dict, skipping empty-string values. For a plain
        dict the value stored under ``field_name`` is returned unchanged.
        A missing or empty value yields ``fields.empty``.

        Raises:
            ValidationError: if ``data`` is not a dict.
        """
        if isinstance(data, MultiValueDict):
            regex = re.compile(r"^%s\.(.*)$" % re.escape(self.field_name))
            ret = {}
            for name, value in data.items():
                match = regex.match(name)
                if not match:
                    continue
                key = match.groups()[0]
                if value != '':
                    ret[key] = value
        elif isinstance(data, dict):
            ret = data.get(self.field_name, fields.empty)
        else:
            raise ValidationError("not a dict: " + str(type(data)))

        if ret is fields.empty:
            return fields.empty
        # Only sized values can be "empty". Anything else found in a plain
        # dict (None, a number, ...) is passed on so that validation can
        # reject it, instead of len() raising TypeError here.
        if hasattr(ret, '__len__') and len(ret) == 0:
            return fields.empty
        return ret

    def to_internal_value(self, data):
        """Check the key constraints and validate each value.

        Returns:
            A dict mapping ``str(key)`` to the child-validated value.

        Raises:
            ValidationError: if ``data`` is not dict-like, contains keys
                outside ``valid_keys``, lacks some of ``required_keys``, or
                the child field rejects a value.
        """
        if not hasattr(data, '__getitem__') or not hasattr(data, 'items'):
            raise ValidationError("not a dict: " + str(type(data)))

        keys = set(data.keys())
        if self.valid_keys is not None:
            unexpected = keys - self.valid_keys
            if unexpected:
                # Report only the offending keys, not every supplied key.
                raise ValidationError("invalid keys in dict: " + str(unexpected))

        if self.required_keys is not None:
            missing = self.required_keys - keys
            if missing:
                # Report the keys that are absent, not the ones present.
                raise ValidationError("missing required keys in dict: " + str(missing))

        return {
            str(key): self.child.run_validation(value)
            for key, value in data.items()
        }

class RangeField(DictField):
    """DictField whose only permitted keys are ``min`` and ``max``.

    Neither key is required; any other key is rejected during validation.
    """
    valid_keys = ('min', 'max')

class GeoPointField(DictField):
    """Geographic point given as a dict with ``lng`` and ``lat`` keys.

    Both keys are mandatory, no others are accepted, and each value is
    validated as a float. The validated result is a GeoJSON-style ``Point``
    dict with the coordinates ordered ``[lng, lat]``.
    """
    valid_keys = ('lng', 'lat')
    required_keys = ('lng', 'lat')

    def __init__(self, **kwargs):
        """Force a ``FloatField`` child, replacing any ``child`` passed in."""
        kwargs.update(child=fields.FloatField())
        super().__init__(**kwargs)

    def to_internal_value(self, data):
        """Validate ``data`` and return it as a ``Point`` dict."""
        point = super().to_internal_value(data)
        lng, lat = point['lng'], point['lat']
        return {'type': 'Point', 'coordinates': [lng, lat]}