import datetime
import os
import re
import yaml

from collections import defaultdict, namedtuple
from faker import Faker
from jinja2 import Environment, PackageLoader
from sqlalchemy.ext.associationproxy import AssociationProxy
from backend.database import Date, DateTime

from import parse_datetime, utcnow

# Matches one ``ClassName(id1, id2, ...)`` group inside an identifier string.
# Raw string: ``\w``, ``\(`` and ``\s`` are regex escapes, not string escapes,
# and a non-raw literal triggers invalid-escape warnings on current Python.
identifier_re = re.compile(r'(?P<class_name>\w+)\((?P<identifiers>[\w,\s]+)\)')

# A reference to one fixture entry: the model class name plus the entry's id.
Identifier = namedtuple('Identifier', 'class_name id')

class AttrGetter:
    """Expose the keys of a mapping as read-only attributes.

    ``AttrGetter({'user': obj}).user`` returns ``obj``.
    """

    def __init__(self, dict_):
        self.dict_ = dict_

    def __getattr__(self, name):
        """Return ``dict_[name]``.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is not a key of the wrapped mapping.
                Raising ``AttributeError`` (rather than leaking ``KeyError``)
                keeps ``hasattr()`` and ``getattr(obj, name, default)``
                working as the attribute protocol expects.
        """
        if name == 'dict_':
            # ``dict_`` itself is missing (e.g. instance built without
            # __init__); looking it up below would recurse forever.
            raise AttributeError(name)
        try:
            return self.dict_[name]
        except KeyError as e:
            raise AttributeError(name) from e

class ModelFactory:
    """Create model instances from YAML fixture files.

    Fixture files live in ``fixtures_dir``, one file per model class, named
    after the class (``<ClassName>.<ext>``). Each file is rendered as a Jinja2
    template (with a ``faker`` global available) and then parsed as YAML into
    a mapping of ``identifier id -> constructor keyword arguments``.

    NOTE(review): nothing in this class adds created instances to the
    session, yet ``_create`` only reuses a cached instance when it is
    ``in self.session`` - presumably callers add them; confirm.
    """

    # The attributes below are class-level and therefore shared by every
    # factory instance; the containers are mutated in place.
    env = None
    loaded_class_names = set()
    class_name_lookup = {}
    model_fixtures = {}  # raw data from yaml files

    def __init__(self, session, model_classes, fixtures_dir):
        """
        Args:
            session: object supporting ``instance in session`` checks.
            model_classes: mapping of class name -> model class.
            fixtures_dir: directory containing the fixture files.
        """
        self.session = session
        self.model_classes = model_classes
        self.fixtures_dir = fixtures_dir
        self.model_instances = {}

    def get_models(self, identifiers):
        """Create all ``identifiers`` and return them behind an AttrGetter,
        so each instance is reachable as an attribute named by its id."""
        return AttrGetter(self.create_all(identifiers))

    def create_all(self, identifiers):
        """Create every referenced fixture.

        Returns:
            dict mapping each identifier id to its model instance. It must be
            a mapping (not a set) because ``get_models`` looks entries up by
            name through ``AttrGetter``.
        """
        identifiers = flatten_identifiers(identifiers)
        return {identifier.id: self._create(identifier)
                for identifier in identifiers}

    def create(self, class_name_or_identifier, identifier=None):
        """Create a single fixture.

        Call either as ``create('ClassName', 'id')`` or with one combined
        identifier string as the only argument, in which case the first
        identifier parsed from it is used.
        """
        if identifier:
            identifier = Identifier(class_name_or_identifier, identifier)
        else:
            identifier = _convert_str(class_name_or_identifier)[0]
        return self._create(identifier)

    def _create(self, identifier):
        """Return the instance for ``identifier``, building it if needed.

        Raises:
            Exception: if ``identifier`` has no class name.
            KeyError: if the class name or the fixture id is unknown.
        """
        if not identifier.class_name:
            raise Exception('Identifier must have a class name!')

        # Lazily read the fixture file for this class; a no-op once loaded.
        self._maybe_load_data([identifier])

        model_class = self.model_classes[identifier.class_name]
        # Instances and raw fixture data are both keyed by identifier id,
        # matching how _load_from_yaml stores the data.
        instance = self.model_instances.get(identifier.id)
        if isinstance(instance, model_class) and instance in self.session:
            return instance

        data = self.model_fixtures[identifier.id]
        instance = model_class(**self.maybe_convert_values(model_class, data))
        self.model_instances[identifier.id] = instance
        return instance

    def maybe_convert_values(self, model_class, data):
        """Return a copy of ``data`` with fixture values converted for
        ``model_class``.

        * Attributes that are association proxies, or whose
          ``impl.uses_objects`` is true, have their identifier value(s)
          replaced by model instances.
        * ``Date`` / ``DateTime`` columns accept the keywords listed below
          or a value handed to ``parse_datetime``.
        * Everything else is passed through unchanged.
        """
        ret = data.copy()
        for col_name, value in data.items():
            col = getattr(model_class, col_name)
            if isinstance(col, AssociationProxy) or col.impl.uses_objects:
                ret[col_name] = self.convert_identifiers(value)
            elif not hasattr(col, 'type'):
                continue
            elif isinstance(col.type, Date):
                # NOTE(review): ``.date`` is not called here. If utcnow() /
                # parse_datetime() return stdlib datetimes this stores a
                # bound method rather than a date - confirm and add ``()``.
                if value in ('today', 'now', 'utcnow'):
                    ret[col_name] = utcnow().date
                elif not isinstance(value, datetime.date):
                    # Mirrors the DateTime branch: values YAML has already
                    # parsed into date objects are left as they are.
                    ret[col_name] = parse_datetime(value).date
            elif isinstance(col.type, DateTime):
                if value in ('now', 'utcnow'):
                    ret[col_name] = utcnow()
                elif not isinstance(value, datetime.datetime):
                    ret[col_name] = parse_datetime(value)
        return ret

    def convert_identifiers(self, identifiers):
        """Convert identifier value(s) into model instance(s).

        A list always yields a list; anything else is delegated to
        ``convert_identifier``.
        """
        if isinstance(identifiers, list):
            return [self._create(identifier)
                    for identifier in flatten_identifiers(identifiers)]
        return self.convert_identifier(identifiers)

    def convert_identifier(self, identifier):
        """Convert one identifier value into model instance(s).

        Returns the instance itself when the value flattens to exactly one
        identifier, otherwise the list of instances.
        """
        result = [self._create(identifier)
                  for identifier in flatten_identifiers(identifier)]
        return result[0] if len(result) == 1 else result

    def _maybe_load_data(self, identifiers):
        """Load fixture files for any class in ``identifiers`` that has not
        been loaded yet.

        ``identifiers`` is an iterable of ``(class_name, id)`` pairs. A
        ``None`` class name causes every file in ``fixtures_dir`` to load.
        """
        class_names = {class_name for class_name, _ in identifiers}
        class_names = class_names.difference(self.loaded_class_names)
        if not class_names:
            return

        for filename in os.listdir(self.fixtures_dir):
            path = os.path.join(self.fixtures_dir, filename)
            if os.path.isfile(path):
                # splitext keeps the whole name when there is no extension
                # (slicing at rfind('.') == -1 would drop the last char).
                class_name = os.path.splitext(filename)[0]
                if None in class_names or class_name in class_names:
                    # Record first so the file is read at most once.
                    self.loaded_class_names.add(class_name)
                    self._load_from_yaml(filename)

    def _load_from_yaml(self, filename):
        """Render ``filename`` as a Jinja2 template, parse the result as
        YAML and store its entries in the class-level fixture tables."""
        if not self.env:
            # NOTE(review): the original loader was a
            # ``PackageLoader('tests', ...)`` whose package path is not
            # recoverable here. Rooting the loader at fixtures_dir - the
            # directory _maybe_load_data lists - guarantees every filename
            # passed in resolves. Imported locally because the module-level
            # jinja2 import does not include FileSystemLoader.
            from jinja2 import FileSystemLoader
            self.env = Environment(loader=FileSystemLoader(self.fixtures_dir))
            faker = Faker()
            self.env.globals['faker'] = faker

        template = self.env.get_template(filename)
        # An empty fixture file parses to None; treat it as "no entries".
        fixture_data = yaml.load(template.render(),
                                 Loader=yaml.FullLoader) or {}

        class_name = os.path.splitext(filename)[0]
        for identifier_id, data in fixture_data.items():
            # FIXME check for dups
            self.class_name_lookup[identifier_id] = class_name
            self.model_fixtures[identifier_id] = data

def flatten_identifiers(identifiers):
    """Normalize any supported identifier form into a flat list.

    Accepted forms:
      * a single ``Identifier``;
      * a string, parsed by ``_convert_str``;
      * a list/tuple of strings and/or ``Identifier`` tuples;
      * a mapping of ``class_name -> values``, where values is a CSV string
        or an iterable of CSV strings.

    Returns:
        list of ``Identifier``, one per individual id.
    """
    if isinstance(identifiers, Identifier):
        # An Identifier is itself a tuple; without this it would be iterated
        # field by field by the list/tuple branch below.
        identifiers = [identifiers]
    elif isinstance(identifiers, str):
        identifiers = _convert_str(identifiers)
    if isinstance(identifiers, (list, tuple)):
        identifiers = _group_by_class_name(identifiers)

    return [Identifier(class_name, identifier)
            for class_name, values in identifiers.items()
            for identifier in _flatten_csv_list(values)]

def _group_by_class_name(identifiers):
    """Group identifier ids by their class name.

    Args:
        identifiers: iterable of ``Identifier`` tuples and/or identifier
            strings (strings are parsed with ``_convert_str``).

    Returns:
        defaultdict mapping ``class_name -> list of id strings``; this is the
        shape ``flatten_identifiers`` iterates with ``.items()``. The ids are
        stored as given and may still be comma-separated.

    Raises:
        Exception: if an element is neither an ``Identifier`` nor a ``str``.
    """
    ret = defaultdict(list)
    for v in identifiers:
        if isinstance(v, Identifier):
            ret[v.class_name].append(v.id)
        elif isinstance(v, str):
            for identifier in _convert_str(v):
                ret[identifier.class_name].append(identifier.id)
        else:
            raise Exception(f'Unexpected value {type(v)}')
    return ret

def _flatten_csv_list(identifiers):
    if isinstance(identifiers, str):
        identifiers = _convert_str(identifiers)
    return [identifier.strip()
            for items in identifiers
            for identifier in items.split(',')]

def _convert_str(value):
    ret = []
    prev = None
    while True:
        match =, prev.end() if prev else 0)
        if not match and not ret:
            return [Identifier(None, value)]
        elif not match:
            return ret

        prev = match