"""Add contributed values table

Revision ID: 159ba85908fd
Revises: d5d88ac1d291
Create Date: 2019-11-01 15:39:50.970246

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy.dialects import postgresql
from qcfractal.storage_sockets.models.sql_base import MsgpackExt
from qcfractal.storage_sockets.models.collections_models import ContributedValuesORM

import numpy as np

# revision identifiers, used by Alembic.
revision = "159ba85908fd"
down_revision = "d5d88ac1d291"
branch_labels = None
depends_on = None

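
# For reference, each legacy `contributed_values_data` JSON blob is assumed to look
# roughly like the sketch below (field names follow the contributed_values columns
# created in upgrade(); the keys and numbers are illustrative only):
#
#   {
#       "some contributed energies": {
#           "name": "Some contributed energies",
#           "theory_level": "DFT",
#           "units": "hartree",
#           "values": {"mol_1": -76.4, "mol_2": -113.3},  # {entry index: value}
#           ...
#       },
#       ...
#   }
#
# migrate_contributed_values_data() below converts each inner dictionary into one
# ContributedValuesORM row, splitting the {index: value} mapping into the parallel
# `index` and `values` arrays stored in the MsgpackExt columns.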

def migrate_contributed_values_data():
    """Copy contributed values from the legacy JSON columns into the new
    contributed_values table, one row per (collection, contributed-value name)."""

    bind = op.get_bind()
    session = orm.Session(bind=bind)

    # Pull the legacy contributed values from both the dataset and reaction_dataset tables
    ds_ids_data = session.execute("select id, contributed_values_data from dataset;").fetchall()
    print(f"Migrating datasets with ids: {[ds[0] for ds in ds_ids_data]}")

    rds_ids_data = session.execute("select id, contributed_values_data from reaction_dataset;").fetchall()
    print(f"Migrating reaction datasets with ids: {[ds[0] for ds in rds_ids_data]}")

    ds_ids_data.extend(rds_ids_data)

    for ds_id, ds_contrib in ds_ids_data:
        if ds_contrib is None:
            continue

        # Each entry maps a contributed-value name to its data dictionary
        for name, dict_values in ds_contrib.items():

            # Split the legacy {index: value} mapping into parallel index/value lists
            idx, vals = [], []
            for entry_index, entry_value in dict_values["values"].items():
                idx.append(entry_index)
                vals.append(entry_value)

            # Store both as arrays so they serialize through the MsgpackExt columns
            dict_values["values"] = np.array(vals)
            dict_values["index"] = np.array(idx)

            cv = ContributedValuesORM(**dict_values)
            cv.collection_id = ds_id

            session.add(cv)

    session.commit()


def upgrade():

    # Rename the old JSON columns so their data is preserved until the new table is populated
    op.alter_column("dataset", "contributed_values", new_column_name="contributed_values_data")
    op.alter_column("reaction_dataset", "contributed_values", new_column_name="contributed_values_data")

    # New normalized table: one row per contributed-value set per collection
    op.create_table(
        "contributed_values",
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("collection_id", sa.Integer(), nullable=False),
        sa.Column("citations", sa.JSON(), nullable=True),
        sa.Column("theory_level", sa.JSON(), nullable=False),
        sa.Column("theory_level_details", sa.JSON(), nullable=True),
        sa.Column("comments", sa.String(), nullable=True),
        sa.Column("values", MsgpackExt(), nullable=False),
        sa.Column("index", MsgpackExt(), nullable=False),
        sa.Column("external_url", sa.String(), nullable=True),
        sa.Column("doi", sa.String(), nullable=True),
        sa.Column("units", sa.String(), nullable=False),
        sa.Column("values_structure", sa.JSON(), nullable=True, default=lambda: {}),
        sa.ForeignKeyConstraint(["collection_id"], ["collection.id"], ondelete="cascade"),
        sa.PrimaryKeyConstraint("name", "collection_id"),
    )

    op.alter_column("contributed_values", "values_structure", server_default=None, nullable=False)

    migrate_contributed_values_data()

    op.drop_column("dataset", "contributed_values_data")
    op.drop_column("reaction_dataset", "contributed_values_data")


def downgrade():
    # Note: the contributed values are not copied back into the JSON columns,
    # so downgrading a populated database loses that data.

    op.add_column(
        "reaction_dataset",
        sa.Column("contributed_values", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
    )
    op.add_column(
        "dataset",
        sa.Column("contributed_values", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
    )
    op.drop_table("contributed_values")