import decimal
import random
from collections import defaultdict
from datetime import datetime
from multiprocessing import Pool
from uuid import uuid4

import six
from anonymizer import replacers
from django.conf import settings
from django.db import connection, transaction
from django.utils.timezone import get_default_timezone
from faker import Faker

from six.moves import xrange

# randrange backed by random.SystemRandom (operating-system entropy source).
randrange = random.SystemRandom().randrange

# Upper-case letters, lower-case letters and digits, in that order.
alphanumeric = "".join(
    chr(code)
    for first, last in (('A', 'Z'), ('a', 'z'), ('0', '9'))
    for code in range(ord(first), ord(last) + 1)
)

# Character pool used for generated free-text values.
general_chars = alphanumeric + " _-"


class DjangoFaker(object):
    """
    Class that provides fake data, using Django specific knowledge to ensure
    acceptable data for Django models.

    Most generators accept an optional model ``field``. When one is given,
    the generated value is cut to ``field.max_length`` (if set) and, for
    fields flagged ``unique``, regenerated until it collides neither with the
    values already stored for that field nor with values handed out earlier
    by this instance.
    """
    # Faker instance used as the underlying source of fake data; defined on
    # the class, so it is shared by all DjangoFaker instances.
    faker = Faker()

    def __init__(self):
        # field -> set of values that are already taken (existing rows plus
        # values generated so far). Filled lazily by _prep_init().
        self.init_values = {}
        # field -> next integer suffix handed out by unique_lorem().
        self.unique_suffixes = defaultdict(int)

    def _prep_init(self, field):
        """
        Load the values currently stored for ``field`` into
        ``self.init_values``. Does nothing if that was already done.
        """
        if field in self.init_values:
            return

        field_vals = set(field.model._default_manager.values_list(field.name, flat=True).iterator())
        self.init_values[field] = field_vals

    def get_allowed_value(self, source, field):
        """
        Call ``source`` (a zero-argument callable) and return a value that is
        acceptable for ``field``.

        If ``field`` is None the raw value is returned. Otherwise the value
        is cut to ``field.max_length`` (when set) and, for unique fields,
        regenerated up to 20 more times until it is unused.

        Raises Exception if no unused value was found for a unique field.
        """
        retval = source()
        if field is None:
            return retval

        max_length = getattr(field, 'max_length', None)

        def fit(value):
            # Enforce max_length
            if max_length is None:
                return value
            return value[:max_length]

        # Truncate *before* the uniqueness check: the value that gets stored
        # is the truncated one, so that is the one that must be unique.
        retval = fit(retval)

        # Enforce unique. Ensure we don't set the same values, as either
        # any of the existing values, or any of the new ones we make up.
        unique = getattr(field, 'unique', None)
        if unique:
            self._prep_init(field)
            used = self.init_values[field]
            for _ in xrange(0, 20):
                if retval not in used:
                    break
                retval = fit(source())

            if retval in used:
                raise Exception("Cannot generate unique data for field %s. Last value tried %s" % (field, retval))
            used.add(retval)

        return retval

    def uuid(self, field=None):
        """
        Returns a random UUID (uuid4) as a string. ``field`` is ignored.
        """
        # bypass chopping from max_length
        return str(uuid4())

    def varchar(self, field=None):
        """
        Returns a chunk of text, of maximum length 'max_length'

        The text has a random length between 1 and ``field.max_length`` and
        is made of letters, digits, space, underscore and hyphen.
        """
        assert field is not None, "The field parameter must be passed to the 'varchar' method."
        max_length = field.max_length

        def source():
            length = random.choice(range(1, max_length + 1))
            return "".join(random.choice(general_chars) for i in xrange(length))
        return self.get_allowed_value(source, field)

    def simple_pattern(self, pattern, field=None):
        """
        Use a simple pattern to make the field - # is replaced with a random number,
        ? with a random letter.
        """
        return self.get_allowed_value(lambda: self.faker.bothify(pattern), field)

    def bool(self, field=None):
        """
        Returns a random boolean
        """
        return self.get_allowed_value(lambda: bool(randrange(0, 2)), field)

    def integer(self, field=None):
        """
        Returns a random integer between -1000000 and 1000000 (inclusive).
        """
        return self.get_allowed_value(lambda: random.randint(-1000000, 1000000), field)

    def positive_integer(self, field=None):
        """
        Returns a random integer between 0 and 1000000 (inclusive).
        """
        return self.get_allowed_value(lambda: random.randint(0, 1000000), field)

    def small_integer(self, field=None):
        """
        Returns a random integer between -32768 and 32767 (inclusive).
        """
        return self.get_allowed_value(lambda: random.randint(-32768, 32767), field)

    def positive_small_integer(self, field=None):
        """
        Returns a random integer between 0 and 32767 (inclusive).
        """
        return self.get_allowed_value(lambda: random.randint(0, 32767), field)

    def datetime(self, field=None, val=None):
        """
        Returns a random datetime. If 'val' is passed, a datetime within two
        years of that date will be returned.

        The result carries Django's default timezone when ``settings.USE_TZ``
        is true and is naive otherwise.
        """
        if val is None:
            def source():
                tzinfo = get_default_timezone() if settings.USE_TZ else None
                return datetime.fromtimestamp(randrange(1, 2100000000),
                                              tzinfo)
        else:
            # time.mktime() on the broken-down time is the portable spelling
            # of the platform-specific strftime("%s") directive.
            import time

            two_years = 365*24*3600*2

            def source():
                tzinfo = get_default_timezone() if settings.USE_TZ else None
                base = int(time.mktime(val.timetuple()))
                return datetime.fromtimestamp(base +
                                              randrange(-two_years, two_years),
                                              tzinfo)
        return self.get_allowed_value(source, field)

    def date(self, field=None, val=None):
        """
        Like datetime, but truncated to be a date only
        """
        return self.datetime(field=field, val=val).date()

    def decimal(self, field=None, val=None):
        """
        Returns a random non-negative Decimal: an integer below 100000
        divided by ``10 ** field.decimal_places``.

        ``field`` is required in practice because ``field.decimal_places`` is
        read; ``val`` is accepted but not used.
        """
        def source():
            return decimal.Decimal(random.randrange(0, 100000))/(10**field.decimal_places)
        return self.get_allowed_value(source, field)

    def postcode(self, field=None):
        """
        Returns a fake postcode from Faker.
        """
        return self.get_allowed_value(self.faker.postcode, field)

    def country(self, field=None):
        """
        Returns a fake country name from Faker.
        """
        return self.get_allowed_value(self.faker.country, field)

    def lorem(self, field=None, val=None):
        """
        Returns lorem ipsum text. If val is provided, the lorem ipsum text will
        be the same length as the original text, and with the same pattern of
        line breaks.
        """
        if val == '':
            return ''

        if val is not None:
            def generate(length):
                # Get lorem ipsum of a specific length.
                collect = ""
                while len(collect) < length:
                    collect += ' %s' % self.faker.sentence()
                collect = collect[:length]
                return collect

            # We want to match the pattern of the text - linebreaks
            # in the same places.
            def source():
                parts = val.split("\n")
                for i, p in enumerate(parts):
                    # Replace each bit with lorem ipsum of the same length
                    parts[i] = generate(len(p))
                return "\n".join(parts)
        else:
            def source():
                return ' '.join(self.faker.sentences())
        return self.get_allowed_value(source, field)

    def unique_lorem(self, field=None, val=None):
        """
        Returns lorem ipsum text guaranteed to be unique. First uses lorem function
        then adds a unique integer suffix.

        The suffix is a per-field counter kept on this instance. When the
        field has a max_length, the last max_length characters are kept so
        that the suffix survives the cut.
        """
        lorem_text = self.lorem(field, val)
        max_length = getattr(field, 'max_length', None)

        suffix_str = str(self.unique_suffixes[field])
        unique_text = lorem_text + suffix_str
        if max_length is not None:
            # take the last max_length chars
            unique_text = unique_text[-max_length:]
        self.unique_suffixes[field] += 1
        return unique_text

    def choice(self, field=None):
        """
        Returns the value part of a randomly picked entry of field.choices.
        """
        assert field is not None, "The field parameter must be passed to the 'choice' method."
        choices = [c[0] for c in field.choices]
        return self.get_allowed_value(lambda: random.choice(choices), field)

    # Other attributes provided by 'Faker':
    # user_name
    # first_name
    # last_name
    # name
    # email
    # address
    # phonenumber
    # street_address
    # city
    # state
    # zip_code
    # company

    def __getattr__(self, name):
        """
        Fallback for names not defined on this class: wrap the Faker
        attribute of the same name so its output goes through
        get_allowed_value().
        """
        # Dunder names are protocol probes (copy, pickle, ...), not data
        # generators; wrapping them would hand back a bogus callable.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        # we delegate most calls to faker, but add checks
        source = getattr(self.faker, name)

        def func(*args, **kwargs):
            # Only the 'field' keyword is honoured; any other argument is
            # ignored and the Faker callable is invoked without arguments.
            field = kwargs.get('field', None)
            return self.get_allowed_value(source, field)
        return func


class Anonymizer(object):
    """
    Base class for all anonymizers. When executed with the ``run()`` method,
    it will anonymize the data for a specific model.

    Subclasses must set ``model`` and ``attributes``.
    """

    model = None

    # attributes is a sequence of (attribute_name, replacer) pairs, where
    # replacer is either the string 'SKIP' (leave the field alone), the name
    # of a function in the 'replacers' module, or a callable that takes as
    # arguments this Anonymizer instance, the object to be altered, the field
    # to be altered, and the current field value, and returns a replacement
    # value.

    # This signature is designed to be useful for making lambdas that call the
    # 'faker' instance provided on this class, but it can be used with any
    # function.
    attributes = None

    # To impose an order on Anonymizers within a module, this can be set - lower
    # values are done first.
    order = 0

    # Fake-data provider; defined on the class, so shared by all anonymizers.
    faker = DjangoFaker()

    def __init__(self):
        """
        Resolve ``attributes`` into ``self.replacers``, a list of
        (attname, field, replacer) tuples. 'SKIP' entries are left out.

        Raises Exception if a replacer is neither a string nor a callable.
        """
        super(Anonymizer, self).__init__()

        assert self.attributes is not None, '"attributes" attribute must be set'
        assert self.model is not None, '"model" attribute must be set'

        self.replacers = []
        for attname, replacer in self.attributes:
            if replacer == 'SKIP':
                continue

            if isinstance(replacer, six.string_types):
                # 'email' is shortcut for: replacers.email
                replacer = getattr(replacers, replacer)
            elif not callable(replacer):
                raise Exception("Expected callable or string to be passed, got %r." % replacer)

            field = self.model._meta.get_field(attname)
            self.replacers.append((attname, field, replacer))

    def get_queryset(self):
        """
        Returns the QuerySet to be manipulated
        """
        return (self.model._default_manager.get_queryset()
                                           .select_related(None)
                                           .order_by('pk'))

    def get_queryset_chunk_iterator(self, chunksize):
        """
        Yield slices of get_queryset(), each covering at most ``chunksize``
        rows. The row count is taken once, before the first slice.

        Raises ValueError if ``chunksize`` is smaller than 1 (the slicing
        offset would never advance).
        """
        if chunksize < 1:
            raise ValueError("chunksize must be at least 1, got %r" % (chunksize,))

        queryset = self.get_queryset()
        num_rows = queryset.count()

        index = 0
        while index < num_rows:
            yield queryset[index:index + chunksize]
            index += chunksize

    def alter_object(self, obj):
        """
        Alters all the attributes in an individual object.

        If it returns False, the object will not be saved
        """
        for attname, field, replacer in self.replacers:
            currentval = getattr(obj, attname)
            replacement = replacer(self, obj, field, currentval)
            setattr(obj, attname, replacement)

    def run(self, chunksize=2000, parallel=4):
        """
        Anonymize every row of get_queryset(), ``chunksize`` rows at a time.

        ``parallel`` is the number of worker processes; 0 processes the
        chunks in the calling process. Does nothing when every attribute is
        skipped.

        Raises ValueError (from validate()) if ``attributes`` does not match
        the model's fields; exceptions raised by a worker are re-raised here.
        """
        self.validate()

        if not self.replacers:
            return

        chunks = self.get_queryset_chunk_iterator(chunksize)

        if parallel == 0:
            for objs in chunks:
                _run(self, objs)
            return

        # Closed before the pool is created - presumably so that worker
        # processes do not inherit the parent's open database connection.
        connection.close()
        # NOTE(review): each worker process works on its own copy of the
        # faker state, so uniqueness tracking is per process - confirm this
        # is acceptable for unique fields.
        pool = Pool(processes=parallel)
        try:
            futures = [pool.apply_async(_run, (self, objs))
                       for objs in chunks]
            for future in futures:
                future.get()
        except BaseException:
            # Do not leave workers running behind a failed run.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

    def validate(self):
        """
        Check that ``attributes`` names exactly the model's fields (by
        attname).

        Raises ValueError listing missing and/or unknown names otherwise.
        """
        model_attrs = set(f.attname for f in self.model._meta.fields)
        given_attrs = set(name for name, replacer in self.attributes)
        if model_attrs == given_attrs:
            return

        # Names are sorted so the message is the same on every run.
        problems = []
        missing_attrs = model_attrs - given_attrs
        if missing_attrs:
            problems.append("The following fields are missing: %s." % ", ".join(sorted(missing_attrs)))
            problems.append("Add the replacer \"SKIP\" to skip these fields.")
        extra_attrs = given_attrs - model_attrs
        if extra_attrs:
            problems.append("The following non-existent fields were supplied: %s." % ", ".join(sorted(extra_attrs)))
        raise ValueError("The attributes list for %s does not match the complete list of fields for that model. %s" % (self.model.__name__, " ".join(problems)))

    def create_query(self, replacer_attrs):
        """
        Build the parameterized UPDATE statement that sets the columns of
        ``replacer_attrs`` (attribute names, in order) for one row selected
        by primary key. The primary key is the last placeholder.
        """
        meta = self.model._meta
        quote_name = connection.ops.quote_name
        # Use each field's database column (it differs from the attribute
        # name when db_column is set) and quote every identifier.
        assignments = ', '.join(
            '%s = %%s' % quote_name(meta.get_field(attr).column)
            for attr in replacer_attrs)
        return 'UPDATE %s SET %s WHERE %s = %%s' % (
            quote_name(meta.db_table),
            assignments,
            quote_name(meta.pk.column))

    def create_query_args(self, updates, replacer_attrs):
        """
        Build the parameter tuples for the statement from create_query().

        ``updates`` maps primary key -> {attribute name: new value}. Each
        returned tuple holds the database-prepared values in
        ``replacer_attrs`` order, followed by the prepared primary key.
        """
        pk_field = self.model._meta.pk
        fields = {attr: self.model._meta.get_field(attr) for attr in replacer_attrs}

        all_args = []
        for k, v in six.iteritems(updates):
            args = [fields[attr].get_db_prep_value(v[attr], connection)
                    for attr in replacer_attrs]

            # pk is always the last argument in this query
            args.append(pk_field.get_db_prep_value(k, connection))
            all_args.append(tuple(args))

        return all_args


def _run(anonymizer, objs):
    """
    Anonymize one chunk of objects and write the result back in bulk.

    Every object of the queryset ``objs`` is passed to
    ``anonymizer.alter_object``; objects for which that returns False are
    left out. The new values of the remaining objects are written with a
    single ``executemany`` inside one transaction.
    """
    attnames = tuple(entry[0] for entry in anonymizer.replacers)

    # pk -> {attribute name: replacement value}
    updates = {}
    for instance in objs.iterator():
        if anonymizer.alter_object(instance) is False:
            continue
        updates[instance.pk] = dict(
            (name, getattr(instance, name)) for name in attnames)

    sql = anonymizer.create_query(attnames)
    params = anonymizer.create_query_args(updates, attnames)

    with transaction.atomic():
        with connection.cursor() as cursor:
            # On PostgreSQL, constraint checks are postponed to the end of
            # the transaction.
            if connection.vendor == 'postgresql':
                cursor.execute('SET CONSTRAINTS ALL DEFERRED')
            cursor.executemany(sql, params)