# -*- coding: utf-8 -*-

import pytest
import peewee as pw

from nplusone.core import signals
import nplusone.ext.peewee  # noqa

from tests.utils import Bunch


@pytest.fixture
def db():
    """Provide an in-memory SQLite database and close it on teardown.

    Yields:
        The ``pw.SqliteDatabase`` instance created for the requesting test.
    """
    database = pw.SqliteDatabase(':memory:')
    yield database
    # Fixtures that depend on this one (e.g. ``session``) are finalized first,
    # so any atomic block they opened has already exited by the time we close.
    if not database.is_closed():
        database.close()


@pytest.fixture
def Base(db):
    """Return a base peewee model class tied to the ``db`` fixture's database."""
    class Base(pw.Model):
        class Meta:
            # The models in this file subclass ``Base`` without declaring a
            # database of their own; this is the only place one is configured.
            database = db
    return Base


@pytest.fixture
def models(Base):
    """Declare the Hobby, User and Address models and return them in a Bunch.

    The Bunch exposes each model under its class name.
    """

    # No explicit fields; it is the target of ``User.hobbies``.
    class Hobby(Base):
        pass

    # Many-to-many with Hobby; the reverse accessor on Hobby is ``users``.
    class User(Base):
        hobbies = pw.ManyToManyField(Hobby, backref='users')

    # Foreign key to User; the reverse accessor on User is ``addresses``.
    class Address(Base):
        user = pw.ForeignKeyField(User, backref='addresses')

    declared = (Hobby, User, Address)
    return Bunch(**{model.__name__: model for model in declared})


@pytest.fixture
def session(db, models):
    """Create the tables for all models, then yield an open atomic block.

    The tables include the through model behind ``User.hobbies``. The atomic
    block stays open while the test runs and exits when the fixture is
    finalized.
    """
    tables = [
        models.User,
        models.Address,
        models.Hobby,
        models.User.hobbies.get_through_model(),
    ]
    db.create_tables(tables, safe=True)
    with db.atomic() as transaction:
        yield transaction


@pytest.fixture
def objects(models, session):
    """Insert one User, one Hobby and one Address (all with id 1).

    The hobby is linked to the user and the address points at the user. The
    created rows are returned in a Bunch as ``user``, ``hobby`` and
    ``address``.
    """
    user_row = models.User.create(id=1)
    hobby_row = models.Hobby.create(id=1)
    # Link the pair through the reverse side of the many-to-many field.
    hobby_row.users.add(user_row)
    return Bunch(
        user=user_row,
        hobby=hobby_row,
        address=models.Address.create(id=1, user=user_row),
    )


class TestManyToOne:
    """Lazy-load detection for the Address -> User foreign key, both directions.

    NOTE: several assertions compare ``call.frame[4]`` against the literal
    source text of the line that triggered the load, so the triggering
    expressions (and the local names used in them) must not be renamed or
    reformatted without updating the expected strings.
    """

    def test_many_to_one(self, models, session, objects, calls, lazy_listener):
        """Reverse accessor on a row from ``select()`` records one lazy load."""
        users = models.User.select()
        list(users[0].addresses)
        assert len(calls) == 1
        call = calls[0]
        assert call.objects == (models.User, 'User:1', 'addresses')
        assert 'users[0].addresses' in ''.join(call.frame[4])
        # Assert on ``.called`` rather than on the ``notify`` attribute itself:
        # the sibling test below reads ``notify.called``, and asserting on the
        # bare attribute of a mock can never fail.
        assert lazy_listener.parent.notify.called

    def test_many_to_one_get(self, models, session, objects, calls, lazy_listener):
        """A row fetched with ``get()`` still records the load but does not notify."""
        user = models.User.get()
        list(user.addresses)
        assert len(calls) == 1
        call = calls[0]
        assert call.objects == (models.User, 'User:1', 'addresses')
        assert 'user.addresses' in ''.join(call.frame[4])
        assert not lazy_listener.parent.notify.called

    def test_many_to_one_prefetch(self, models, session, objects, calls, lazy_listener):
        """Prefetching addresses alongside users records no lazy load."""
        users = pw.prefetch(
            models.User.select(),
            models.Address.select(),
        )
        list(users[0].addresses)
        assert len(calls) == 0

    def test_many_to_one_ignore(self, models, session, objects, calls):
        """Accesses made inside ``signals.ignore(signals.lazy_load)`` are not recorded."""
        user = models.User.select().first()
        with signals.ignore(signals.lazy_load):
            # Bare attribute access is intentional: the access itself is
            # what is being tested.
            user.addresses
        assert len(calls) == 0

    def test_many_to_one_reverse(self, models, session, objects, calls):
        """Following the foreign key from Address to User records one lazy load."""
        address = models.Address.select().first()
        # Bare attribute access is intentional: it triggers the load.
        address.user
        assert len(calls) == 1
        call = calls[0]
        assert call.objects == (models.Address, 'Address:1', 'user')
        assert 'address.user' in ''.join(call.frame[4])

    def test_many_to_one_reverse_join(self, models, session, objects, calls):
        """Selecting User together with Address via a join records no lazy load."""
        address = models.Address.select(
            models.Address,
            models.User,
        ).join(
            models.User
        ).first()
        address.user
        assert len(calls) == 0

    def test_many_to_one_reverse_prefetch(self, models, session, objects, calls):
        """Prefetching users alongside addresses records no lazy load."""
        addresses = pw.prefetch(
            models.Address.select(),
            models.User.select(),
        )
        addresses[0].user
        assert len(calls) == 0


class TestManyToMany:
    """Lazy-load detection for the User <-> Hobby many-to-many field.

    NOTE: the ``call.frame[4]`` assertions match against the literal source
    text of the triggering lines, so those expressions and the local names in
    them must stay in sync with the expected strings.
    """

    def test_many_to_many(self, models, session, objects, calls):
        """Iterating ``hobbies`` on a selected user records one lazy load."""
        users = models.User.select()
        list(users[0].hobbies)
        assert len(calls) == 1
        call = calls[0]
        assert call.objects == (models.User, 'User:1', 'hobbies')
        # ``frame[4]`` is presumably the captured source context of the
        # calling frame -- TODO confirm against the ``calls`` fixture.
        assert 'users[0].hobbies' in ''.join(call.frame[4])

    def test_many_to_many_reverse(self, models, session, objects, calls):
        """Iterating the ``users`` backref on a hobby records one lazy load."""
        hobby = models.Hobby.select().first()
        list(hobby.users)
        assert len(calls) == 1
        call = calls[0]
        assert call.objects == (models.Hobby, 'Hobby:1', 'users')
        assert 'hobby.users' in ''.join(call.frame[4])