from __future__ import print_function

import collections, re, sys, time

try:
  import colorama
  colorama.init()
except ImportError:
  print('colorama not installed')

try:
  import peewee as pw
  import playhouse.migrate
except ImportError:
  print('peewee or playhouse not installed')
  # don't raise: setup.py must be able to import this module just to read __version__

if sys.version_info >= (3,0):
  raw_input = input
  from collections.abc import Iterable
else:
  from collections import Iterable



DEBUG = False
PW3 = 'pw' in globals() and not hasattr(pw, 'Clause')

# peewee applies field defaults in Python, not as database-side DEFAULT clauses - doh!
DIFF_DEFAULTS = False

__version__ = '3.7.6'


# Python 2/3 compat: Python 3 has no builtin unicode, so alias it to str.
try:
  unicode
except NameError:
  unicode = str


###################
# Peewee 2/3 Shims
###################

if PW3:
  def extract_query_from_migration(migration):
    if isinstance(migration, Iterable):
      # Postgresql puts the context first, MySQL puts it last :(
      ctx = next(obj for obj in migration if isinstance(obj, pw.Context))
    else:
      ctx = migration
    return [ctx.query()]

  def _table_name(model):
    return model._meta.table_name

  def _column_name(field):
    return field.column_name

  def _field_type(field):
    return field.field_type

  def _is_foreign_key(field):
    return hasattr(field, 'rel_model')

  def create_table(model):
    manager = pw.SchemaManager(model)
    ctx = manager._create_table()
    return [(''.join(ctx._sql), ctx._values)]

  def rename_table(migrator, before, after):
    migration = migrator.rename_table(before, after, with_context=True)
    return extract_query_from_migration(migration)

  def drop_table(migrator, name):
    migration = migrator.make_context().literal('DROP TABLE ').sql(pw.Entity(name))
    return extract_query_from_migration(migration)

  def create_index(model, fields, unique):
    manager = pw.SchemaManager(model)
    ctx = manager._create_index(pw.ModelIndex(model, fields, unique=unique))
    return [(''.join(ctx._sql), ctx._values)]

  def drop_index(migrator, model, index):
    migration = migrator.make_context().literal('DROP INDEX ').sql(pw.Entity(index.name))
    if is_mysql(model._meta.database):
      migration = migration.literal(' ON ').sql(pw.Entity(_table_name(model)))
    return extract_query_from_migration(migration)

  def create_foreign_key(field):
    manager = pw.SchemaManager(field.model)
    ctx = manager._create_foreign_key(field)
    return [(''.join(ctx._sql), ctx._values)]

  def drop_foreign_key(db, migrator, table_name, fk_name):
    drop_stmt = ' DROP FOREIGN KEY ' if is_mysql(db) else ' DROP CONSTRAINT '
    migration = migrator._alter_table(migrator.make_context(), table_name).literal(drop_stmt).sql(pw.Entity(fk_name))
    return extract_query_from_migration(migration)

  def drop_default(db, migrator, table_name, column_name, field):
    ctx = migrator.make_context()
    migration = migrator._alter_column(ctx, table_name, column_name).literal(' DROP DEFAULT')
    return extract_query_from_migration(migration)

  def set_default(db, migrator, table_name, column_name, field):
    default = field.default
    if callable(default): default = default()
    migration = ( migrator.make_context()
      .literal('UPDATE ').sql(pw.Entity(table_name))
      .literal(' SET ').sql(pw.Expression(pw.Entity(column_name), pw.OP.EQ, field.db_value(default), flat=True))
      .literal(' WHERE ').sql(pw.Expression(pw.Entity(column_name), pw.OP.IS, pw.SQL('NULL'), flat=True))
    )
    return extract_query_from_migration(migration)

  def alter_add_column(db, migrator, ntn, column_name, field):
    migration = migrator.alter_add_column(ntn, column_name, field, with_context=True)
    to_run = extract_query_from_migration(migration)
    if is_mysql(db) and _is_foreign_key(field):
      to_run += create_foreign_key(field)
    return to_run

  def drop_not_null(migrator, ntn, defined_col):
    migration = migrator.drop_not_null(ntn, defined_col.name, with_context=True)
    return extract_query_from_migration(migration)

  def rename_column(db, migrator, table, ocn, ncn, field):
    if is_mysql(db):
      ctx = migrator.make_context()
      migration = migrator._alter_table(ctx, table).literal(' CHANGE ').sql(pw.Entity(ocn)).literal(' ').sql(field.ddl(ctx))
    else:
      migration = migrator.rename_column(table, ocn, ncn, with_context=True)
    return extract_query_from_migration(migration)

  def drop_column(db, migrator, table, column_name):
    migrator.explicit_delete_foreign_key = False
    migration = migrator.drop_column(table, column_name, cascade=False, with_context=True)
    return extract_query_from_migration(migration)

  def change_column_type(db, migrator, table_name, column_name, field):
    column_type = _field_type(field)
    ctx = migrator.make_context()
    if is_postgres(db):
      migration = migrator._alter_column(ctx, table_name, column_name).literal(' TYPE ').sql(field.ddl_datatype(ctx))
    elif is_mysql(db):
      migration = migrator._alter_table(ctx, table_name).literal(' MODIFY COLUMN ').sql(field.ddl(ctx))
    else:
      raise Exception('how do I change a column type for %s?' % db)

    return extract_query_from_migration(migration)

  def add_not_null(db, migrator, table, column_name, field):
    cmds = []
    if field.default is not None:
      cmds += set_default(db, migrator, table, column_name, field)
    if is_mysql(db):
      ctx = migrator.make_context()
      cmds.append(migrator._alter_table(ctx, table).literal(' MODIFY COLUMN ').sql(field.ddl(ctx)).query())
    else:
      migration = migrator.add_not_null(table, column_name, with_context=True)
      cmds += extract_query_from_migration(migration)
    return cmds

  def indexes_on_model(model):
    return [
      pw.IndexMetadata('', '', [_column_name(f) for f in idx._expressions], idx._unique, _table_name(model))
      for idx in model._meta.fields_to_index()
    ]

else:
  def normalize_op_to_clause(migrator, op):
    if isinstance(op, pw.Clause): return op
    kwargs = op.kwargs.copy()
    kwargs['generate'] = True
    ret = getattr(migrator, op.method)(*op.args, **kwargs)
    return ret

  def _table_name(cls):
    return cls._meta.db_table

  def _column_name(cls):
    return cls.db_column

  def _field_type(field):
    compiler = field.model_class._meta.database.compiler()
    return compiler.get_column_type(field.get_db_field())

  def _is_foreign_key(field):
    return isinstance(field, pw.ForeignKeyField)

  def create_table(cls):
    compiler = cls._meta.database.compiler()
    return [compiler.create_table(cls)]

  def rename_table(migrator, before, after):
    op = migrator.rename_table(before, after, generate=True)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, op)

  def drop_table(migrator, table_name):
    compiler = migrator.database.compiler()
    return [compiler.parse_node(pw.Clause(pw.SQL('DROP TABLE'), pw.Entity(table_name)))]

  def create_index(model, fields, unique):
    compiler = model._meta.database.compiler()
    return [compiler.create_index(model, fields, unique)]

  def drop_index(migrator, model, index):
    op = migrator.drop_index(_table_name(model), index.name, generate=True)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, op)

  def create_foreign_key(field):
    compiler = field.model_class._meta.database.compiler()
    return [compiler.create_foreign_key(field.model_class, field)]

  def drop_foreign_key(db, migrator, table_name, fk_name):
    drop_stmt = 'drop foreign key' if is_mysql(db) else 'DROP CONSTRAINT'
    op = pw.Clause(pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL(drop_stmt), pw.Entity(fk_name))
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, op)

  def drop_default(db, migrator, table_name, column_name, field):
    op = pw.Clause(pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL('ALTER COLUMN'), pw.Entity(column_name), pw.SQL('DROP DEFAULT'))
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, op)

  def set_default(db, migrator, table_name, column_name, field):
    default = field.default
    if callable(default): default = default()
    param = pw.Param(field.db_value(default))
    op = pw.Clause(pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL('ALTER COLUMN'), pw.Entity(column_name), pw.SQL('SET DEFAULT'), param)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, op)

  def alter_add_column(db, migrator, ntn, column_name, field):
    operation = migrator.alter_add_column(ntn, column_name, field, generate=True)
    to_run = normalize_whatever_junk_peewee_migrations_gives_you(migrator, operation)
    if is_mysql(db) and _is_foreign_key(field):
      to_run += create_foreign_key(field)
    return to_run

  def drop_not_null(migrator, ntn, defined_col):
    op = migrator.drop_not_null(ntn, defined_col.name, generate=True)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, op)

  def rename_column(db, migrator, ntn, ocn, ncn, field):
    compiler = db.compiler()
    if is_mysql(db):
      junk = pw.Clause(
        pw.SQL('ALTER TABLE'), pw.Entity(ntn), pw.SQL('CHANGE'), pw.Entity(ocn), compiler.field_definition(field)
      )
    else:
      junk = migrator.rename_column(ntn, ocn, ncn, generate=True)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, junk)

  def drop_column(db, migrator, ntn, column_name):
    migrator.explicit_delete_foreign_key = False
    op = migrator.drop_column(ntn, column_name, generate=True, cascade=False)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, op)

  def change_column_type(db, migrator, table_name, column_name, field):
    column_type = _field_type(field)
    if is_postgres(db):
      op = pw.Clause(pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL('ALTER'), field.as_entity(), pw.SQL('TYPE'), field.__ddl_column__(column_type))
    elif is_mysql(db):
      op = pw.Clause(*[pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL('MODIFY')] + field.__ddl__(column_type))
    else:
      raise Exception('how do I change a column type for %s?' % db)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, op)

  def normalize_whatever_junk_peewee_migrations_gives_you(migrator, junk):
    # sometimes a clause, sometimes an operation, sometimes a list mixed with clauses and operations
    # turn it into a list of (sql,params) tuples
    compiler = migrator.database.compiler()
    if not hasattr(junk, '__iter__'):
      junk = [junk]
    junk = [normalize_op_to_clause(migrator, o) for o in junk]
    junk = [compiler.parse_node(clause) for clause in junk]
    return junk

  def add_not_null(db, migrator, table, column_name, field):
    cmds = []
    compiler = db.compiler()
    if field.default is not None:
      # if default is a function, call it to get a value
      # this won't work on columns requiring uniqueness, like UUIDs,
      # as all rows would share the same computed value
      default = field.default() if hasattr(field.default, '__call__') else field.default
      op = pw.Clause(pw.SQL('UPDATE'), pw.Entity(table), pw.SQL('SET'), field.as_entity(), pw.SQL('='), default, pw.SQL('WHERE'), field.as_entity(), pw.SQL('IS NULL'))
      cmds.append(compiler.parse_node(op))
    if is_postgres(db) or is_sqlite(db):
      junk = migrator.add_not_null(table, column_name, generate=True)
      cmds += normalize_whatever_junk_peewee_migrations_gives_you(migrator, junk)
      return cmds
    elif is_mysql(db):
      op = pw.Clause(pw.SQL('ALTER TABLE'), pw.Entity(table), pw.SQL('MODIFY'), compiler.field_definition(field))
      cmds.append(compiler.parse_node(op))
      return cmds
    raise Exception('how do I add a not null for %s?' % db)

  def indexes_on_model(model):
    return [pw.IndexMetadata('', '', [_column_name(f)], f.unique, _table_name(model)) for f in model._fields_to_index()]

####

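# Foreign keys among newly-created tables are marked deferred so the tables can
# be created first, in any order; calc_changes() then emits the FK constraints
# separately once every table exists.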
def mark_fks_as_deferred(table_names):
  add_fks = []
  table_names_to_models = {_table_name(cls): cls for cls in all_models.keys() if _table_name(cls) in table_names}

  for model in table_names_to_models.values():
    for field in model._meta.sorted_fields:
      if _is_foreign_key(field):
        add_fks.append(field)
        if not field.deferred:
          field.__pwdbev__not_deferred = True
          field.deferred = True
  return add_fks

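# A model can declare the table name(s) it used to have via Meta.aka, so a
# rename is detected instead of a drop + create. Illustrative sketch (the
# names here are made up):
#
#   class Person(pw.Model):
#     class Meta:
#       database = db
#       aka = 'people'  # or a list of former table names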
def calc_table_changes(existing_tables, ignore_tables=None):
  if ignore_tables:
    ignore_tables = set(ignore_tables) | globals()['ignore_tables']
  else:
    ignore_tables = globals()['ignore_tables']
  existing_tables = set(existing_tables)
  table_names_to_models = {unicode(_table_name(cls)):cls for cls in all_models.keys()}
  defined_tables = set(table_names_to_models.keys())
  adds = defined_tables - existing_tables - ignore_tables
  deletes = existing_tables - defined_tables - ignore_tables
  renames = {}
  for to_add in list(adds):
    cls = table_names_to_models[to_add]
    if hasattr(cls._meta, 'aka'):
      akas = cls._meta.aka
      if hasattr(akas, 'lower'):
        akas = [akas]
      for a in akas:
        a = unicode(a)
        if a in deletes:
          renames[a] = to_add
          adds.remove(to_add)
          deletes.remove(a)
          break
  add_fks = mark_fks_as_deferred(adds)
  return adds, add_fks, deletes, renames

def is_postgres(db):
  return isinstance(db, pw.PostgresqlDatabase)

def is_mysql(db):
  return isinstance(db, pw.MySQLDatabase)

def is_sqlite(db):
  return isinstance(db, pw.SqliteDatabase)

def auto_detect_migrator(db):
  if is_postgres(db):
    return playhouse.migrate.PostgresqlMigrator(db)
  if is_sqlite(db):
    return playhouse.migrate.SqliteMigrator(db)
  if is_mysql(db):
    return playhouse.migrate.MySQLMigrator(db)
  raise Exception("could not auto-detect migrator for %s - please provide one via the migrator kwarg" % repr(db.__class__.__name__))

_re_varchar = re.compile(r'^varchar[(]\d+[)]$')
def normalize_column_type(t):
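  # illustrative examples of the canonicalization below:
  #   'VARCHAR(255)' -> 'varchar'    'serial' -> 'integer'
  #   'timestamp without time zone' -> 'timestamp'    'bytea' -> 'blob'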
  t = t.lower()
  if t in ['serial', 'int', 'integer auto_increment', 'auto']: t = 'integer'
  if t in ['timestamp without time zone', 'datetime']: t = 'timestamp'
  if t in ['timestamp with time zone', 'datetime_tz']: t = 'timestamptz'
  if t in ['time without time zone']: t = 'time'
  if t in ['character varying']: t = 'varchar'
  if _re_varchar.match(t): t = 'varchar'
  if t in ['decimal', 'real', 'float']: t = 'numeric'
  if t in ['boolean']: t = 'bool'
  if t in ['bytea']: t = 'blob'
  return unicode(t)


def normalize_default(default):
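  # illustrative examples (Postgres-style reported defaults):
  #   "'hello'::character varying" -> 'hello'
  #   "nextval('person_id_seq'::regclass)" -> None (sequence defaults aren't diffed)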
  if default is None: return None
  if hasattr(default, 'lower'):
    default = unicode(default)
    if default.startswith('nextval('): return None
    default = default.split('::')[0]
    default = default.strip("'")
  return default

def can_convert(type1, type2):
  if type1=='array': return False
  return True

def are_data_types_equal(db, type_a, type_b):
  if type_a == type_b: return True
  type_a, type_b = sorted([type_a, type_b])
  if is_mysql(db) and type_a=='bool' and type_b=='tinyint': return True
  if is_postgres(db) and type_a=='char' and type_b=='character': return True
  return False

def column_def_changed(db, a, b):
  # b is the defined column
  return (
    a.null!=b.null or
    not are_data_types_equal(db, a.data_type, b.data_type) or
    (b.max_length is not None and a.max_length!=b.max_length) or
    (b.precision is not None and a.precision!=b.precision) or
    (b.scale is not None and a.scale!=b.scale) or
    a.primary_key!=b.primary_key or
    (DIFF_DEFAULTS and normalize_default(a.default)!=normalize_default(b.default))
  )

ColumnMetadata = collections.namedtuple('ColumnMetadata', (
  'name', 'data_type', 'null', 'primary_key', 'table', 'default', 'max_length', 'precision', 'scale'
))
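# e.g. (illustrative) a serial primary key column is reported roughly as:
#   ColumnMetadata(name='id', data_type='integer', null=False, primary_key=True,
#                  table='person', default=None, max_length=None, precision=None, scale=None)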

def get_columns_by_table(db, schema=None):
  columns_by_table = collections.defaultdict(list)
  if is_postgres(db) or is_mysql(db):
    if schema is None and is_mysql(db):
      schema_check = 'c.table_schema=DATABASE()'
      params = []
    else:
      schema_check = 'c.table_schema=%s'
      params = [schema or 'public']
    sql = '''
        select
          c.column_name,
          c.data_type,
          c.is_nullable='YES' as is_nullable,
          coalesce(tc.constraint_type='PRIMARY KEY',false) as primary_key,
          c.table_name,
          c.column_default,
          c.character_maximum_length as max_length,
          c.numeric_precision,
          c.numeric_scale
        from information_schema.columns as c
        left join information_schema.key_column_usage as kcu
        on (c.table_name=kcu.table_name and c.table_schema=kcu.table_schema and c.column_name=kcu.column_name)
        left join information_schema.table_constraints as tc
        on (tc.table_name=kcu.table_name and tc.table_schema=kcu.table_schema and tc.constraint_name=kcu.constraint_name)
        where %s
        order by c.ordinal_position
    ''' % schema_check
    cursor = db.execute_sql(sql, params)
  else:
    raise Exception("don't know how to get columns for %s" % db)
  for row in cursor.fetchall():
    data_type = normalize_column_type(row[1])
    max_length = None if row[6]==4294967295 else row[6] # MySQL returns 4294967295L for LONGTEXT fields
    default = None if row[5] is not None and row[5].startswith('nextval') else row[5]
    precision = row[7] if data_type=='numeric' else None
    scale = row[8] if data_type=='numeric' else None
    column = ColumnMetadata(row[0], data_type, row[2], row[3], row[4], default, max_length, precision, scale)
    columns_by_table[column.table].append(column)
  return columns_by_table

ForeignKeyMetadata = collections.namedtuple('ForeignKeyMetadata', ('column', 'dest_table', 'dest_column', 'table', 'name'))

def get_foreign_keys_by_table(db, schema=None):
  fks_by_table = collections.defaultdict(list)
  if is_postgres(db):
    sql = """
      select kcu.column_name, ccu.table_name, ccu.column_name, tc.table_name, tc.constraint_name
      from information_schema.table_constraints as tc
      join information_schema.key_column_usage as kcu
        on (tc.constraint_name = kcu.constraint_name and tc.constraint_schema = kcu.constraint_schema)
      join information_schema.constraint_column_usage as ccu
        on (ccu.constraint_name = tc.constraint_name and ccu.constraint_schema = tc.constraint_schema)
      where tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = %s
    """
    cursor = db.execute_sql(sql, (schema or 'public',))
  elif is_mysql(db):
    sql = """
      select column_name, referenced_table_name, referenced_column_name, table_name, constraint_name
      from information_schema.key_column_usage
      where table_schema=database() and referenced_table_name is not null and referenced_column_name is not null
    """
    cursor = db.execute_sql(sql, [])
  elif is_sqlite(db):
    # does not work: this returns raw CREATE TABLE statements (one column per row),
    # but the loop below expects 5-column FK rows
    sql = """
      SELECT sql
        FROM (
              SELECT sql sql, type type, tbl_name tbl_name, name name
                FROM sqlite_master
               UNION ALL
              SELECT sql, type, tbl_name, name
                FROM sqlite_temp_master
             )
       WHERE type != 'meta'
         AND sql NOTNULL
         AND name NOT LIKE 'sqlite_%'
         AND sql LIKE '%REFERENCES%'
       ORDER BY substr(type, 2, 1), name
    """
    cursor = db.execute_sql(sql, [])
  else:
    raise Exception("don't know how to get FKs for %s" % db)
  for row in cursor.fetchall():
    fk = ForeignKeyMetadata(row[0], row[1], row[2], row[3], row[4])
    fks_by_table[fk.table].append(fk)
  return fks_by_table

def get_indexes_by_table(db, table, schema=None):
  # peewee's get_indexes returns the columns in an index in arbitrary order
  if is_postgres(db):
    query = '''
      select index_class.relname,
        idxs.indexdef,
        array_agg(table_attribute.attname order by array_position(index.indkey, table_attribute.attnum)),
        index.indisunique,
        table_class.relname
      from pg_catalog.pg_class index_class
      join pg_catalog.pg_index index on index_class.oid = index.indexrelid
      join pg_catalog.pg_class table_class on table_class.oid = index.indrelid
      join pg_catalog.pg_attribute table_attribute on table_class.oid = table_attribute.attrelid and table_attribute.attnum = any(index.indkey)
      join pg_catalog.pg_indexes idxs on idxs.tablename = table_class.relname and idxs.indexname = index_class.relname
      where table_class.relname = %s and table_class.relkind = %s and idxs.schemaname = %s
      group by index_class.relname, idxs.indexdef, index.indisunique, table_class.relname;
    '''
    cursor = db.execute_sql(query, (table, 'r', schema or 'public'))
    return [pw.IndexMetadata(*row) for row in cursor.fetchall()]
  else:
    return db.get_indexes(table, schema=schema)

def calc_column_changes(db, migrator, etn, ntn, existing_columns, defined_fields, existing_fks_by_column):
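  # Diff one table's live columns against the model's fields. Returns
  # (new_cols, delete_cols, {old_name: new_name} renames, [(sql, params), ...] alters).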
  defined_fields_by_column_name = {unicode(_column_name(f)):f for f in defined_fields}
  defined_columns = [ColumnMetadata(
    unicode(_column_name(f)),
    normalize_column_type(_field_type(f)),
    f.null,
    f.primary_key,
    unicode(ntn),
    f.default,
    f.max_length if hasattr(f, 'max_length') else None,
    f.max_digits if hasattr(f, 'max_digits') else None,
    f.decimal_places if hasattr(f, 'decimal_places') else None,
  ) for f in defined_fields if isinstance(f, pw.Field)]

  existing_cols_by_name = {c.name:c for c in existing_columns}
  defined_cols_by_name = {c.name:c for c in defined_columns}
  existing_col_names = set(existing_cols_by_name.keys())
  defined_col_names = set(defined_cols_by_name.keys())
  new_cols = defined_col_names - existing_col_names
  delete_cols = existing_col_names - defined_col_names
  rename_cols = {}
  for cn in list(new_cols):
    sc = defined_cols_by_name[cn]
    field = defined_fields_by_column_name[cn]
    if hasattr(field, 'akas'):
      for aka in field.akas:
        if aka in delete_cols:
          ec = existing_cols_by_name[aka]
          if can_convert(sc.data_type, ec.data_type):
            rename_cols[ec.name] = sc.name
            new_cols.discard(cn)
            delete_cols.discard(aka)

  alter_statements = []
  renames_new_to_old = {v:k for k,v in rename_cols.items()}
  not_new_columns = defined_col_names - new_cols

  # look for column metadata changes
  for col_name in not_new_columns:
    existing_col = existing_cols_by_name[renames_new_to_old.get(col_name, col_name)]
    defined_col = defined_cols_by_name[col_name]
    field = defined_fields_by_column_name[defined_col.name]
    if column_def_changed(db, existing_col, defined_col):
      len_alter_statements = len(alter_statements)

      different_type = existing_col.data_type != defined_col.data_type
      different_length = defined_col.max_length is not None and existing_col.max_length != defined_col.max_length
      different_precision = defined_col.precision is not None and existing_col.precision != defined_col.precision
      different_scale = defined_col.scale is not None and existing_col.scale != defined_col.scale

      should_cast = different_type and can_convert(existing_col.data_type, defined_col.data_type)
      should_recast = not different_type and (different_length or different_precision or different_scale)

      if existing_col.null and not defined_col.null:
        alter_statements += add_not_null(db, migrator, ntn, defined_col.name, field)
      if not existing_col.null and defined_col.null:
        alter_statements += drop_not_null(migrator, ntn, defined_col)
      if should_cast or should_recast:
        stmts = change_column_type(db, migrator, ntn, defined_col.name, field)
        alter_statements += stmts
      if DIFF_DEFAULTS:
        if normalize_default(existing_col.default) is not None and normalize_default(defined_col.default) is None:
          alter_statements += drop_default(db, migrator, ntn, defined_col.name, field)
        elif normalize_default(existing_col.default) != normalize_default(defined_col.default):
          alter_statements += set_default(db, migrator, ntn, defined_col.name, field)
      if not (len_alter_statements < len(alter_statements)):
        if existing_col.data_type == u'array':
          # type reporting for arrays is broken in peewee:
          # it returns the array's element type, not 'array',
          # so ignore array columns for now (HACK)
          pass
        else:
          raise Exception("In table %s I don't know how to change %s into %s" % (repr(ntn), existing_col, defined_col))

  # look for fk changes
  for col_name in not_new_columns:
    existing_column_name = renames_new_to_old.get(col_name, col_name)
    defined_field = defined_fields_by_column_name[col_name]
    existing_fk = existing_fks_by_column.get(existing_column_name)
    foreign_key = _is_foreign_key(defined_field)
    if foreign_key and not existing_fk and not (hasattr(defined_field, 'fake') and defined_field.fake):
      alter_statements += create_foreign_key(defined_field)
    if not foreign_key and existing_fk:
      alter_statements += drop_foreign_key(db, migrator, ntn, existing_fk.name)
  return new_cols, delete_cols, rename_cols, alter_statements


def calc_changes(db, ignore_tables=None, schema=None):
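  # Top-level diff: compare every registered model against the live schema and
  # return the full list of (sql, params) statements needed to reconcile them.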
  migrator = auto_detect_migrator(db) # TODO: expose as a kwarg?

  existing_tables = [unicode(t) for t in (db.get_tables(schema=schema) if schema else db.get_tables())]
  existing_indexes = {table:get_indexes_by_table(db, table, schema=schema) for table in existing_tables}
  existing_columns_by_table = get_columns_by_table(db, schema=schema)
  foreign_keys_by_table = get_foreign_keys_by_table(db, schema=schema)

  table_names_to_models = {_table_name(cls): cls for cls in all_models.keys()}

  to_run = []

  table_adds, add_fks, table_deletes, table_renames = calc_table_changes(existing_tables, ignore_tables=ignore_tables)
  table_renamed_from = {v: k for k, v in table_renames.items()}
  for tbl in table_adds:
    to_run += create_table(table_names_to_models[tbl])
  for field in add_fks:
    if hasattr(field, '__pwdbev__not_deferred') and field.__pwdbev__not_deferred:
      field.deferred = False
    to_run += create_foreign_key(field)
  for k, v in table_renames.items():
    to_run += rename_table(migrator, k, v)


  rename_cols_by_table = {}
  deleted_cols_by_table = {}
  for etn, ecols in existing_columns_by_table.items():
    if etn in table_deletes: continue
    ntn = table_renames.get(etn, etn)
    model = table_names_to_models.get(ntn)
    if not model: continue
    defined_fields = model._meta.sorted_fields
    
    # composite keys do not come from peewee w/ the primary key bit set
    if isinstance(model._meta.primary_key, pw.CompositeKey):
      for field_name in model._meta.primary_key.field_names:
        for field in defined_fields:
          if field_name==field.name and not field.primary_key:
            field.primary_key = True
    
    defined_column_name_to_field = {unicode(_column_name(f)):f for f in defined_fields}
    existing_fks_by_column = {fk.column:fk for fk in foreign_keys_by_table[etn]}
    adds, deletes, renames, alter_statements = calc_column_changes(db, migrator, etn, ntn, ecols, defined_fields, existing_fks_by_column)
    for column_name in adds:
      field = defined_column_name_to_field[column_name]
      to_run += alter_add_column(db, migrator, ntn, column_name, field)
      if not field.null:
        # alter_add_column strips null constraints
        # add them back after setting any defaults
        if field.default is not None:
          to_run += set_default(db, migrator, ntn, column_name, field)
        else:
          to_run.append(('-- adding a not null column without a default will fail if the table is not empty',[]))
        to_run += add_not_null(db, migrator, ntn, column_name, field)

    for column_name in deletes:
      fk = existing_fks_by_column.get(column_name)
      if fk:
        to_run += drop_foreign_key(db, migrator, ntn, fk.name)
      to_run += drop_column(db, migrator, ntn, column_name)
    for ocn, ncn in renames.items():
      field = defined_column_name_to_field[ncn]
      to_run += rename_column(db, migrator, ntn, ocn, ncn, field)
    to_run += alter_statements
    rename_cols_by_table[ntn] = renames
    deleted_cols_by_table[ntn] = deletes

  for ntn, model in table_names_to_models.items():
    etn = table_renamed_from.get(ntn, ntn)
    deletes = deleted_cols_by_table.get(ntn,set())
    existing_indexes_for_table = [i for i in existing_indexes.get(etn, []) if not any([(c in deletes) for c in i.columns])]
    to_run += calc_index_changes(db, migrator, existing_indexes_for_table, model, rename_cols_by_table.get(ntn, {}))

  # TODO: permission diffing is not implemented; pseudocode sketch:
  #   to_run += calc_perms_changes($schema_tables, noop) unless $check_perms_for.empty?
  for tbl in table_deletes:
    to_run += drop_table(migrator, tbl)
  return to_run

def indexes_are_same(i1, i2):
  return unicode(i1.table)==unicode(i2.table) and i1.columns==i2.columns and i1.unique==i2.unique

def normalize_indexes(indexes):
  return [(unicode(idx.table), tuple(unicode(c) for c in idx.columns), idx.unique) for idx in indexes]
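# e.g. (illustrative) a unique index on person(last_name, first_name) normalizes to:
#   ('person', ('last_name', 'first_name'), True)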


def calc_index_changes(db, migrator, existing_indexes, model, renamed_cols):
  to_run = []
  fields = list(model._meta.sorted_fields)
  fields_by_column_name = {_column_name(f):f for f in fields}
  pk_cols = set([unicode(_column_name(f)) for f in fields if f.primary_key])
  existing_indexes = [i for i in existing_indexes if not all([(unicode(c) in pk_cols) for c in i.columns])]
  normalized_existing_indexes = normalize_indexes(existing_indexes)
  existing_indexes_by_normalized_existing_indexes = dict(zip(normalized_existing_indexes, existing_indexes))
  normalized_existing_indexes = set(normalized_existing_indexes)
  defined_indexes = indexes_on_model(model)
  for fields, unique in model._meta.indexes:
    try:
      columns = [_column_name(model._meta.fields[fname]) for fname in fields]
    except KeyError as e:
      raise Exception("Index %s on %s references field %s in a multi-column index, but that field doesn't exist. (Be sure to use the field name, not the db_column name, when specifying a multi-column index.)" % ((fields, unique), model.__name__, repr(e.message)))
    defined_indexes.append(pw.IndexMetadata('', '', columns, unique, _table_name(model)))
  normalized_defined_indexes = set(normalize_indexes(defined_indexes))
  to_add = normalized_defined_indexes - normalized_existing_indexes
  to_del = normalized_existing_indexes - normalized_defined_indexes
  for index in to_del:
    index = existing_indexes_by_normalized_existing_indexes[index]
    to_run += drop_index(migrator, model, index)
  for index in to_add:
    to_run += create_index(model, [fields_by_column_name[col] for col in index[1]], index[2])
  return to_run

def evolve(db, interactive=True, ignore_tables=None, schema=None):
  if interactive:
    print((colorama.Style.BRIGHT + colorama.Fore.RED + 'Making updates to database: {}'.format(db.database) + colorama.Style.RESET_ALL))
  to_run = calc_changes(db, ignore_tables=ignore_tables, schema=schema)
  if not to_run:
    if interactive:
      print('Nothing to do... Your database is up to date!')
    return

  commit = True
  if interactive:
    commit = _confirm(db, to_run)
  _execute(db, to_run, interactive=interactive, commit=commit)


def _execute(db, to_run, interactive=True, commit=True):
  if interactive: print()
  try:
    with db.atomic() as txn:
      for sql, params in to_run:
        if interactive or DEBUG: print_sql(' %s; %s' % (sql, params or ''))
        if sql.strip().startswith('--'): continue
        db.execute_sql(sql, params)
      if interactive:
        print()
        print(
          (colorama.Style.BRIGHT + 'SUCCESS!' + colorama.Style.RESET_ALL) if commit else 'TEST PASSED - ROLLING BACK',
          colorama.Style.DIM + '-',
          'https://github.com/keredson/peewee-db-evolve' + colorama.Style.RESET_ALL
        )
        print()
      if not commit:
        txn.rollback()
  except Exception as e:
    print()
    print('------------------------------------------')
    print(colorama.Style.BRIGHT + colorama.Fore.RED + ' SQL EXCEPTION - ROLLING BACK ALL CHANGES' + colorama.Style.RESET_ALL)
    print('------------------------------------------')
    print()
    raise

COLORED_WORDS = None

def init_COLORED_WORDS():
  global COLORED_WORDS
  COLORED_WORDS = [
    (colorama.Fore.GREEN, ['CREATE', 'ADD']),
    (colorama.Fore.YELLOW, ['ALTER', 'SET', 'RENAME']),
    (colorama.Fore.RED, ['DROP']),
    (colorama.Style.BRIGHT + colorama.Fore.BLUE, ['INTEGER','VARCHAR','TIMESTAMP','TEXT','SERIAL']),
    (colorama.Style.BRIGHT, ['BEGIN','COMMIT']),
    (colorama.Fore.CYAN, ['FOREIGN KEY', 'REFERENCES', 'UNIQUE']),
    (colorama.Style.BRIGHT + colorama.Fore.CYAN, ['PRIMARY KEY']),
    (colorama.Style.BRIGHT + colorama.Fore.MAGENTA, ['NOT NULL','NULL']),
    (colorama.Style.DIM, [' ON ', '(', ')', 'INDEX', 'TABLE', 'COLUMN', 'CONSTRAINT' ,' TO ',';']),
  ]

def print_sql(sql):
  if COLORED_WORDS is None: init_COLORED_WORDS()
  for color, patterns in COLORED_WORDS:
    for pattern in patterns:
      sql = sql.replace(pattern, color + pattern + colorama.Style.RESET_ALL)
  print(sql)


def _confirm(db, to_run):
  print()
  print("Your database needs the following %s:" % ('changes' if len(to_run)>1 else 'change'))
  print()
  if is_postgres(db): print_sql('  BEGIN TRANSACTION;\n')
  for sql, params in to_run:
    print_sql('  %s; %s' % (sql, params or ''))
  if is_postgres(db): print_sql('\n  COMMIT;')
  print()
  while True:
    print('Do you want to run %s? (%s)' % (('these commands' if len(to_run)>1 else 'this command'), ('type yes, no or test' if is_postgres(db) else 'yes or no')), end=' ')
    response = raw_input().strip().lower()
    if response=='yes' or (is_postgres(db) and response=='test'):
      break
    if response=='no':
      sys.exit(1)
  print('Running in', end=' ')
  for i in range(3):
    print('%i...' % (3-i), end=' ')
    sys.stdout.flush()
    time.sleep(1)
  print()
  return response=='yes'




all_models = {}
ignore_tables = set()

def register(model):
  if model.__module__=='playhouse.sqlite_ext':
    return
  if hasattr(model._meta, 'evolve') and not model._meta.evolve:
    ignore_tables.add(_table_name(model))
  else:
    all_models[model] = []
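
# A model can opt out of evolution entirely (its table is then ignored), e.g.:
#   class Legacy(pw.Model):
#     class Meta:
#       database = db
#       evolve = False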

def unregister(model):
  del all_models[model]

def clear():
  all_models.clear()
  ignore_tables.clear()

def _add_model_hook():
  ModelBase = pw.BaseModel if hasattr(pw, 'BaseModel') else pw.ModelBase
  init = ModelBase.__init__
  def _init(*args, **kwargs):
    cls = args[0]
    fields = args[3]
    if '__module__' in fields:
      del fields['__module__']
    register(cls)
    init(*args, **kwargs)
  ModelBase.__init__ = _init

def _add_field_hook():
  init = pw.Field.__init__
  def _init(*args, **kwargs):
    self = args[0]
    if 'aka' in kwargs:
      akas = kwargs['aka']
      if hasattr(akas, 'lower'):
        akas = [akas]
      self.akas = akas
      del kwargs['aka']
    init(*args, **kwargs)
  pw.Field.__init__ = _init

def _add_fake_fk_field_hook():
  init = pw.ForeignKeyField.__init__
  def _init(*args, **kwargs):
    self = args[0]
    if 'fake' in kwargs:
      self.fake = kwargs['fake']
      del kwargs['fake']
    init(*args, **kwargs)
  pw.ForeignKeyField.__init__ = _init
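
# 'fake' foreign keys behave like FKs in Python, but evolve() will neither create
# nor expect a database-level constraint for them (see calc_column_changes), e.g.:
#   owner = pw.ForeignKeyField(Person, fake=True)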

def add_evolve():
  pw.Database.evolve = evolve

if 'pw' in globals():
  _add_model_hook()
  _add_field_hook()
  _add_fake_fk_field_hook()
  add_evolve()


__all__ = ['evolve']
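

# Minimal usage sketch, assuming this module is importable as 'peeweedbevolve'
# (the model, field, and database names below are illustrative):
#
#   import peewee as pw
#   import peeweedbevolve  # importing installs the model/field hooks above
#
#   db = pw.PostgresqlDatabase('example_db')
#
#   class Person(pw.Model):
#     name = pw.CharField(aka='full_name')  # field-level aka marks a column rename
#     class Meta:
#       database = db
#
#   db.evolve()  # added by add_evolve(); prompts before applying schema changes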