import asyncio

import sqlalchemy as sa
from sqlalchemy.schema import CreateTable

from aiopg.sa import create_engine

# Schema shared by every scenario in this example: ``metadata`` collects the
# table definitions and ``users`` is the single table the transactions touch.
metadata = sa.MetaData()

users = sa.Table(
    'users_sa_transaction',
    metadata,
    sa.Column('id', sa.Integer(), primary_key=True),
    sa.Column('name', sa.String(length=255)),
)


async def create_sa_transaction_tables(conn):
    """(Re)create the ``users`` table used by the transaction scenarios.

    Any copy of the table left over from a previous run is dropped first, so
    the script can be run repeatedly and every scenario starts from an empty
    table.

    :param conn: an acquired ``aiopg.sa`` connection.
    """
    # Without this, a second run fails in CREATE TABLE because the table
    # already exists. It also clears rows left behind by a run that crashed.
    await conn.execute('DROP TABLE IF EXISTS {}'.format(users.name))
    await conn.execute(CreateTable(users))


async def check_count_users(conn, *, count):
    """Verify that the ``users`` table currently holds exactly *count* rows.

    :param conn: an ``aiopg.sa`` connection (may be inside a transaction).
    :param count: the expected number of rows.
    :raises AssertionError: if the actual row count differs from *count*.
    """
    # Let the database count the rows instead of fetching all of them.
    query = sa.select(sa.func.count()).select_from(users)
    actual = await conn.scalar(query)
    # An explicit raise (rather than ``assert``) keeps the check active even
    # when Python runs with -O.
    if actual != count:
        raise AssertionError(
            'expected {} rows in {}, found {}'.format(
                count, users.name, actual))


async def success_transaction(conn):
    """Insert two users, then delete them, each step in its own transaction.

    The row count is checked before, between and after the two committed
    transactions (0, then 2, then 0).
    """
    await check_count_users(conn, count=0)

    # Both inserts are committed together when the block exits normally.
    async with conn.begin():
        for user_id, user_name in ((1, 'test1'), (2, 'test2')):
            await conn.execute(
                sa.insert(users).values(id=user_id, name=user_name))

    await check_count_users(conn, count=2)

    async with conn.begin():
        for user_id in (1, 2):
            await conn.execute(sa.delete(users).where(users.c.id == user_id))

    await check_count_users(conn, count=0)


async def fail_transaction(conn):
    """Show that rolling back an explicit transaction discards its insert.

    A row is inserted and a ``RuntimeError`` is raised on purpose. The handler
    rolls the transaction back, so the table is still empty afterwards. The
    commit in the ``else`` branch is never reached in this demo.
    """
    await check_count_users(conn, count=0)

    tx = await conn.begin()
    new_user = sa.insert(users).values(id=1, name='test1')
    try:
        await conn.execute(new_user)
        raise RuntimeError()
    except RuntimeError:
        # Undo the insert performed inside the failed block.
        await tx.rollback()
    else:
        await tx.commit()

    await check_count_users(conn, count=0)


async def success_nested_transaction(conn):
    """Insert users in two nested ``begin_nested()`` blocks, then clean up.

    Both nested blocks exit normally, so both inserts persist (count 2). A
    plain transaction then deletes the rows again (count 0).
    """
    await check_count_users(conn, count=0)

    insert_first = sa.insert(users).values(id=1, name='test1')
    insert_second = sa.insert(users).values(id=2, name='test2')

    async with conn.begin_nested():
        await conn.execute(insert_first)
        async with conn.begin_nested():
            await conn.execute(insert_second)

    await check_count_users(conn, count=2)

    async with conn.begin():
        for user_id in (1, 2):
            await conn.execute(sa.delete(users).where(users.c.id == user_id))

    await check_count_users(conn, count=0)


async def fail_nested_transaction(conn):
    """Roll back an inner nested transaction and retry its work in a new one.

    Inside an outer ``begin_nested()`` block, user 1 is inserted. A second
    nested transaction inserts user 2 and is rolled back after a deliberate
    ``RuntimeError``. User 2 is then inserted again in a fresh nested block,
    so two rows remain once the outer block exits. A final transaction deletes
    both rows.
    """
    await check_count_users(conn, count=0)

    async with conn.begin_nested():
        await conn.execute(sa.insert(users).values(id=1, name='test1'))

        inner = await conn.begin_nested()
        try:
            await conn.execute(sa.insert(users).values(id=2, name='test2'))
            raise RuntimeError()
        except RuntimeError:
            # Only the inner transaction is undone; user 1 is kept
            # (the count of 2 checked below depends on this).
            await inner.rollback()
        else:
            await inner.commit()

        # Retry the insert that was just rolled back.
        async with conn.begin_nested():
            await conn.execute(sa.insert(users).values(id=2, name='test2'))

    await check_count_users(conn, count=2)

    async with conn.begin():
        for user_id in (1, 2):
            await conn.execute(sa.delete(users).where(users.c.id == user_id))

    await check_count_users(conn, count=0)


async def fail_first_nested_transaction(conn):
    """Roll back the outermost transaction after its nested ones completed.

    User 1 is inserted directly, and users 2 and 3 inside their own
    ``begin_nested()`` blocks, which exit normally. A deliberate
    ``RuntimeError`` then makes the outer transaction roll back, and the
    final check expects the table to be empty again.
    """
    # Same precondition as the other scenarios: start from an empty table,
    # so leftover data is reported here instead of surfacing later as a
    # duplicate-key error or a confusing final count.
    await check_count_users(conn, count=0)

    trans = await conn.begin_nested()

    try:
        await conn.execute(sa.insert(users).values(id=1, name='test1'))

        async with conn.begin_nested():
            await conn.execute(sa.insert(users).values(id=2, name='test2'))

        async with conn.begin_nested():
            await conn.execute(sa.insert(users).values(id=3, name='test3'))

        raise RuntimeError()

    except RuntimeError:
        # Rolling back the outer transaction also discards the work of the
        # nested blocks that already exited.
        await trans.rollback()
    else:
        await trans.commit()

    await check_count_users(conn, count=0)


async def go():
    """Create the example table and run every transaction scenario in order.

    Connects to the PostgreSQL database at 127.0.0.1 with the example
    credentials. The engine is closed when the ``async with`` block exits.
    """
    async with create_engine(user='aiopg',
                             database='aiopg',
                             host='127.0.0.1',
                             password='passwd') as engine:
        async with engine.acquire() as conn:
            await create_sa_transaction_tables(conn)

            scenarios = (
                success_transaction,
                fail_transaction,
                success_nested_transaction,
                fail_nested_transaction,
                fail_first_nested_transaction,
            )
            for scenario in scenarios:
                await scenario(conn)


if __name__ == '__main__':
    # asyncio.run() creates a fresh event loop, runs go() and closes the loop.
    # Calling get_event_loop() with no running loop is deprecated in modern
    # Python. The guard keeps the demo from running on import.
    asyncio.run(go())