import pytest
from marshmallow import Schema, fields
from sqlalchemy import Column, Integer

from flask_resty import Api, GenericModelView, StrictRule
from flask_resty.testing import assert_response

# -----------------------------------------------------------------------------


@pytest.fixture
def models(db):
    """Define the ``Widget`` model and create its table for a single test.

    Yields a dict mapping ``"widget"`` to the model class. After the test
    finishes, all tables are dropped again.
    """
    # ``pytest.fixture`` handles generator (yield) fixtures natively; the
    # old ``pytest.yield_fixture`` alias is deprecated.

    class Widget(db.Model):
        __tablename__ = "widgets"

        id = Column(Integer, primary_key=True)

    db.create_all()

    yield {"widget": Widget}

    # Teardown: runs once the test using this fixture has completed.
    db.drop_all()


@pytest.fixture
def schemas():
    """Return marshmallow schema instances keyed by resource name."""

    class WidgetSchema(Schema):
        # The integer primary key is rendered as a string in responses.
        id = fields.Integer(as_string=True)

    widget_schema = WidgetSchema()
    return {"widget": widget_schema}


@pytest.fixture
def views(models, schemas):
    """Build the widget view classes used by the tests in this module.

    Returns a dict of view classes keyed by ``"widget_list"``, ``"widget"``
    and ``"custom_widget"``. The class names are significant:
    ``test_view_func_wrapper`` looks up ``"WidgetView"`` in
    ``app.view_functions``.
    """

    # Shared configuration: every view works with the Widget model and
    # serializes through the widget schema.
    class WidgetViewBase(GenericModelView):
        model = models["widget"]
        schema = schemas["widget"]

    # Collection endpoint: GET lists widgets, POST creates one.
    class WidgetListView(WidgetViewBase):
        def get(self):
            return self.list()

        def post(self):
            # Accept an ``id`` supplied in the request body
            # (test_create_client_id posts ``{"id": "100"}``).
            return self.create(allow_client_id=True)

    # Item endpoint: GET retrieves a single widget by id.
    class WidgetView(WidgetViewBase):
        def get(self, id):
            return self.retrieve(id)

    # Item endpoint with overridden persistence hooks, used by
    # test_delete_return_item.
    class CustomWidgetView(WidgetViewBase):
        def delete(self, id):
            return self.destroy(id)

        # Returns a newly constructed Widget(id=9) instead of the updated
        # item. No PUT/PATCH handler is defined on this view, so nothing in
        # this module appears to exercise it.
        def update_item_raw(self, widget, data):
            return self.model(id=9)

        # Returns a newly constructed Widget(id=9) instead of the deleted
        # item. Presumably ``destroy`` forwards this value to
        # ``make_deleted_response`` (test_delete_return_item expects
        # ``{"id": "9"}``); confirm against flask_resty.
        def delete_item_raw(self, widget):
            return self.model(id=9)

        # Respond with the serialized item rather than whatever
        # GenericModelView's default deleted response is.
        def make_deleted_response(self, widget):
            return self.make_item_response(widget)

    return {
        "widget_list": WidgetListView,
        "widget": WidgetView,
        "custom_widget": CustomWidgetView,
    }


@pytest.fixture(autouse=True)
def data(db, models):
    """Seed the database with one widget before every test in the module."""
    widget_model = models["widget"]
    db.session.add(widget_model())
    db.session.commit()


# -----------------------------------------------------------------------------


def test_api_prefix(app, views, client, base_client):
    """Resources on a prefixed ``Api`` are served beneath that prefix.

    ``client`` is hit with the bare path and ``base_client`` with the full
    ``/api`` path. Presumably ``client`` prepends the prefix itself
    (defined in conftest, not visible here).
    """
    api = Api(app, "/api")
    api.add_resource("/widgets", views["widget_list"])

    assert_response(client.get("/widgets"), 200, [{"id": "1"}])
    assert_response(base_client.get("/api/widgets"), 200, [{"id": "1"}])


def test_rule_without_slash(app, views, client):
    """A rule without a trailing slash does not match the slashed path."""
    api = Api(app, "/api")
    api.add_resource("/widgets", views["widget_list"])

    for path, expected_status in (("/widgets", 200), ("/widgets/", 404)):
        assert_response(client.get(path), expected_status)


def test_rule_with_slash(app, views, client):
    """A rule with a trailing slash redirects (308) the unslashed path."""
    api = Api(app, "/api")
    api.add_resource("/widgets/", views["widget_list"])

    for path, expected_status in (("/widgets", 308), ("/widgets/", 200)):
        assert_response(client.get(path), expected_status)


def test_no_append_slash(monkeypatch, app, views, client):
    """With ``StrictRule`` the unslashed path is a 404 instead of a redirect."""
    # The rule class must be swapped before any resource is registered.
    monkeypatch.setattr(app, "url_rule_class", StrictRule)

    api = Api(app, "/api")
    api.add_resource("/widgets/", views["widget_list"])

    for path, expected_status in (("/widgets", 404), ("/widgets/", 200)):
        assert_response(client.get(path), expected_status)


def test_create_client_id(app, views, client):
    """POST honors a client-supplied id and points ``Location`` at it."""
    api = Api(app)
    api.add_resource("/widgets", views["widget_list"], views["widget"])

    created = client.post("/widgets", data={"id": "100"})

    location = created.headers["Location"]
    assert location == "http://localhost/widgets/100"
    assert_response(created, 201, {"id": "100"})


def test_create_no_location(app, views, client):
    """No ``Location`` header is emitted when ``get_location`` gives ``None``."""

    def no_location(self, item):
        return None

    # Patching the class directly is safe: the ``views`` fixture is
    # function-scoped, so every test gets freshly defined classes.
    views["widget_list"].get_location = no_location

    api = Api(app)
    api.add_resource("/widgets", views["widget_list"], views["widget"])

    created = client.post("/widgets", data={})
    assert "Location" not in created.headers
    assert_response(created, 201, {"id": "2"})


def test_training_slash(app, views, client):
    """A trailing-slash ``id_rule`` is used for item URLs and ``Location``.

    The test name misspells "trailing"; it is kept so the test ID stays
    stable.
    """
    api = Api(app)
    api.add_resource(
        "/widgets/", views["widget_list"], views["widget"], id_rule="<id>/"
    )

    response = client.post("/widgets/", data={"id": "100"})
    assert response.headers["Location"] == "http://localhost/widgets/100/"

    assert_response(response, 201, {"id": "100"})

    # Check with assert_response, as the rest of the module does, and pin
    # the body too, so the item actually created above is verified.
    response = client.get("/widgets/100/")
    assert_response(response, 200, {"id": "100"})


def test_resource_rules(app, views, client):
    """Explicit base/alternate rules route items and collections separately.

    ``Location`` for a newly created item uses the base (item) rule.
    """
    api = Api(app)
    api.add_resource(
        base_rule="/widget/<id>",
        base_view=views["widget"],
        alternate_rule="/widgets",
        alternate_view=views["widget_list"],
    )

    fetched = client.get("/widget/1")
    assert_response(fetched, 200, {"id": "1"})

    created = client.post("/widgets", data={})
    assert created.headers["Location"] == "http://localhost/widget/2"
    assert_response(created, 201, {"id": "2"})


def test_factory_pattern(app, views, client):
    """An ``Api`` built without an app needs ``app=`` on ``add_resource``.

    Calling ``init_app`` alone is not enough: registering a resource
    without an explicit app raises ``AssertionError``.
    """
    api = Api()
    api.init_app(app)

    list_view = views["widget_list"]

    with pytest.raises(AssertionError, match="no application specified"):
        api.add_resource("/widgets", list_view)

    api.add_resource("/widgets", list_view, app=app)

    assert_response(client.get("/widgets"), 200, [{"id": "1"}])


def test_view_func_wrapper(app, views):
    """Registered view functions keep their view class's name."""
    api = Api(app)
    api.add_resource("/widgets", views["widget_list"], views["widget"])

    # Placeholder for checking that view metadata (e.g. what custom New
    # Relic instrumentation reads) survives the wrapping.
    view_func = app.view_functions["WidgetView"]
    assert view_func.__name__ == "WidgetView"


def test_delete_return_item(app, views, client):
    """DELETE on the custom view responds with a serialized item body.

    ``CustomWidgetView`` overrides its delete hooks to produce a
    ``Widget(id=9)``, so the response body carries id ``"9"``.
    """
    api = Api(app)
    api.add_resource("/widgets/<int:id>", views["custom_widget"])

    assert_response(client.delete("/widgets/1"), 200, {"id": "9"})