# -*- coding: utf-8 -*-

import threading

import flask
import pytest
import webtest
from flask_sqlalchemy import SQLAlchemy

from nplusone.core import signals
from nplusone.core import exceptions
from nplusone.ext.wsgi import NPlusOneMiddleware
import nplusone.ext.sqlalchemy  # noqa

from tests import utils


def get_worker():
    return str(threading.current_thread().ident)


@pytest.fixture(scope='module', autouse=True)
def setup():
    signals.get_worker = get_worker


@pytest.fixture
def db():
    return SQLAlchemy()


@pytest.fixture
def models(db):
    return utils.make_models(db.Model)


@pytest.fixture()
def objects(db, app, models):
    hobby = models.Hobby()
    address = models.Address()
    user = models.User(addresses=[address], hobbies=[hobby])
    db.session.add(user)
    db.session.commit()
    db.session.close()


@pytest.fixture
def app(db, models):
    app = flask.Flask(__name__)
    app.config['TESTING'] = True
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    db.init_app(app)
    with app.app_context():
        db.create_all()
        yield app


@pytest.fixture
def routes(app, models, wrapper):
    @app.route('/many_to_one/')
    def many_to_one():
        users = models.User.query.all()
        return str(users[0].addresses)

    @app.route('/many_to_one_one/')
    def many_to_one_one():
        user = models.User.query.filter_by(id=1).one()
        return str(user.addresses)


@pytest.fixture
def wrapper(app):
    return NPlusOneMiddleware(app)


@pytest.fixture
def client(routes, wrapper):
    return webtest.TestApp(wrapper)


class TestNPlusOneMiddleware:

    def test_many_to_one(self, objects, client):
        with pytest.raises(exceptions.NPlusOneError):
            client.get('/many_to_one/')

    def test_many_to_one_one(self, objects, client):
        client.get('/many_to_one_one/')