from os import listdir
import os.path as osp
from random import shuffle
import random
import shlex
import subprocess
import sqlite3
import datetime

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import psutil
import pandas as pd
from pandas.io.sql import DatabaseError
import psycopg2
from psycopg2.extensions import register_adapter, AsIs
from psycopg2.sql import SQL, Identifier
import torch
import torch.nn.functional as F


def print_with_time(x):
    """Print ``x`` prefixed with the current local time as HH:MM:SS.mmm."""
    stamp = datetime.datetime.now().strftime('%H:%M:%S.%f')
    # %f renders microseconds (6 digits); dropping the last three digits
    # leaves millisecond precision.
    stamp = stamp[:-3]
    print("%s: %s" % (stamp, x))


class RemainingTasksTaken(Exception):
    """Raised by get_a_task when no idle task could be claimed and at least
    one task of the population is still marked active."""


class PopulationFinished(Exception):
    """Raised by get_a_task when no task is active and every task has
    reached the interval limit."""


class ExploitationNeeded(Exception):
    """Raised by get_a_task when no task is idle or active, some tasks are
    unfinished, and every task is flagged ready_for_exploitation."""


class ExploitationOcurring(Exception):
    """Raised by get_a_task when no task could be claimed, none is active,
    some are unfinished, and at least one task is not flagged
    ready_for_exploitation."""


class LossIsNaN(Exception):
    """Signals a NaN loss value.

    NOTE(review): nothing in this module raises it; presumably raised by
    the training code -- confirm against callers.
    """


def register_numpy_types():
    """Register psycopg2's AsIs adapter for numpy scalar types.

    Credit: https://github.com/musically-ut/psycopg2_numpy_ext

    Types registered (when the running numpy build provides them):
      - numpy.int8, numpy.int16, numpy.int32, numpy.int64
      - numpy.float16, numpy.float32, numpy.float64, numpy.float128
      - numpy.bool_
    """
    for typ in ['int8', 'int16', 'int32', 'int64',
                'float16', 'float32', 'float64', 'float128',
                'bool_']:
        numpy_type = getattr(np, typ, None)
        if numpy_type is None:
            # Extended-precision aliases such as float128 only exist on
            # platforms whose C long double supports them; skip instead of
            # failing with AttributeError.
            continue
        register_adapter(numpy_type, AsIs)


def get_task_ids_and_scores(connect_str_or_path, use_sqlite, population_id):
    """Return the task ids and scores of one population, ordered by score.

    Args:
        connect_str_or_path: SQLite database path when ``use_sqlite`` is
            true, otherwise a psycopg2 connection string.
        use_sqlite: Select the SQLite backend instead of PostgreSQL.
        population_id: Value matched against ``populations.population_id``.

    Returns:
        ``(task_ids, scores)``: two parallel lists in ``score DESC`` order.
    """
    if use_sqlite:
        conn = sqlite3.connect(connect_str_or_path)
        placeholder = "?"
    else:
        conn = psycopg2.connect(connect_str_or_path)
        placeholder = "%s"
    # NOTE(review): rows with a NULL score sort last under SQLite but first
    # under PostgreSQL for ORDER BY ... DESC -- confirm callers cope.
    command = """
              SELECT task_id, score
              FROM populations
              WHERE population_id = {}
              ORDER BY score DESC
              """.format(placeholder)
    # try/finally so a failing query does not leak the cursor/connection.
    try:
        cur = conn.cursor()
        try:
            cur.execute(command, [population_id])
            results = cur.fetchall()
        finally:
            cur.close()
    finally:
        conn.close()
    task_ids = [row[0] for row in results]
    scores = [row[1] for row in results]
    return task_ids, scores


def get_col_from_populations(connect_str_or_path, use_sqlite,
                             population_id, column_name):
    """Return one column of the ``populations`` table for a population.

    Args:
        connect_str_or_path: SQLite database path when ``use_sqlite`` is
            true, otherwise a psycopg2 connection string.
        use_sqlite: Select the SQLite backend instead of PostgreSQL.
        population_id: Value matched against ``populations.population_id``.
        column_name: Name of the column to read. Must be a plain identifier.

    Returns:
        List with the column value of every matching row.

    Raises:
        ValueError: With SQLite, if ``column_name`` is not a plain
            identifier (it has to be interpolated into the statement).
    """
    if use_sqlite:
        # sqlite3 cannot bind identifiers as parameters, so the name is
        # interpolated; restricting it to an identifier blocks SQL injection
        # and matches the PostgreSQL branch, which quotes it as Identifier.
        if not (isinstance(column_name, str) and column_name.isidentifier()):
            raise ValueError("invalid column name: %r" % (column_name,))
        conn = sqlite3.connect(connect_str_or_path)
        command = "SELECT {} FROM populations WHERE population_id = ?"
        command = command.format(column_name)
    else:
        conn = psycopg2.connect(connect_str_or_path)
        command = "SELECT {} FROM populations WHERE population_id = %s"
        command = SQL(command).format(Identifier(column_name))
    # try/finally so a failing query does not leak the cursor/connection.
    try:
        cur = conn.cursor()
        try:
            cur.execute(command, [population_id])
            rows = cur.fetchall()
        finally:
            cur.close()
    finally:
        conn.close()
    return [row[0] for row in rows]


def update_table(connect_str_or_path, use_sqlite, table_name, key_value_pairs,
                 where_string=None, where_variables=None):
    """Run ``UPDATE table_name SET key = value, ...`` with a WHERE clause.

    Args:
        connect_str_or_path: SQLite database path when ``use_sqlite`` is
            true, otherwise a psycopg2 connection string.
        use_sqlite: Select the SQLite backend instead of PostgreSQL.
        table_name: Table to update.
        key_value_pairs: Mapping of column name to new value. Callables and
            classes are stored by their ``__name__``.
        where_string: Complete WHERE clause using the backend's placeholder
            style (``?`` for SQLite, ``%s`` for PostgreSQL). When omitted,
            the row whose ``id`` equals ``key_value_pairs['id']`` is updated.
        where_variables: Sequence of values bound to the placeholders of
            ``where_string``; ``None`` means the clause has no placeholders.

    Raises:
        KeyError: If ``where_string`` is omitted and ``key_value_pairs`` has
            no ``'id'`` entry.

    Note:
        With SQLite the table and column names are interpolated unquoted
        into the statement, so they must come from trusted code.
    """
    values = [v.__name__ if callable(v) or isinstance(v, type) else v
              for v in key_value_pairs.values()]
    placeholder = "?" if use_sqlite else "%s"
    if where_string is None:
        where_string = "WHERE id = " + placeholder
        where_variables = [key_value_pairs['id']]
    elif where_variables is None:
        # A custom clause without placeholders needs no bound values.
        where_variables = []
    parameters = values + list(where_variables)

    if use_sqlite:
        conn = sqlite3.connect(connect_str_or_path)
        assignments = ", ".join("{} = ?".format(field)
                                for field in key_value_pairs)
        # where_string is appended verbatim, never passed through .format().
        command = " ".join(["UPDATE {}".format(table_name),
                            "SET " + assignments,
                            where_string])
    else:
        register_numpy_types()
        conn = psycopg2.connect(connect_str_or_path)
        assignments = SQL(", ").join(
            [SQL("{} = %s").format(Identifier(field))
             for field in key_value_pairs])
        command = SQL(" ").join(
            [SQL("UPDATE {}").format(Identifier(table_name)),
             SQL("SET {}").format(assignments),
             SQL(where_string)])
    # try/finally so a failing statement does not leak the connection.
    try:
        cur = conn.cursor()
        try:
            cur.execute(command, parameters)
            conn.commit()
        finally:
            cur.close()
    finally:
        conn.close()


def update_task(connect_str_or_path, use_sqlite,
                population_id, task_id, key_value_pairs):
    """Update columns of one task row in the ``populations`` table.

    The row is selected by ``population_id`` and ``task_id``;
    ``key_value_pairs`` maps column names to their new values. The work is
    delegated to ``update_table``.
    """
    # SQLite binds with "?", psycopg2 with "%s".
    marker = "?" if use_sqlite else "%s"
    row_filter = "WHERE population_id = {0} AND task_id = {0}".format(marker)
    update_table(connect_str_or_path, use_sqlite,
                 "populations", key_value_pairs,
                 where_string=row_filter,
                 where_variables=[population_id, task_id])


def get_a_task(connect_str_or_path, use_sqlite, population_id, interval_limit):
    """Claim one idle task of a population and return its training state.

    A task is idle when it is neither ``active`` nor
    ``ready_for_exploitation``. The claimed task is marked active.

    Args:
        connect_str_or_path: SQLite database path when ``use_sqlite`` is
            true, otherwise a psycopg2 connection string.
        use_sqlite: Select the SQLite backend instead of PostgreSQL.
        population_id: Population to take a task from.
        interval_limit: Number of trained intervals at which a task counts
            as finished.

    Returns:
        ``(task_id, intervals_trained, seed_for_shuffling)``.

    Raises:
        RemainingTasksTaken: No idle task, but some task is active.
        PopulationFinished: Nothing active and every task has
            ``intervals_trained >= interval_limit``.
        ExploitationNeeded: Nothing idle or active, some tasks unfinished,
            and every task is ready for exploitation.
        ExploitationOcurring: As above, but at least one task is not ready.
    """
    if use_sqlite:
        conn = sqlite3.connect(connect_str_or_path)
        # NOTE(review): the SQLite path selects and then marks the task in
        # two steps without a lock, so concurrent workers can claim the
        # same task.
        command_get_task = """
                           SELECT task_id
                           FROM populations
                           WHERE population_id = ?
                           AND ready_for_exploitation = 0
                           AND active = 0
                           LIMIT 1
                           """
        command_lock_task = """
                            UPDATE populations
                            SET active = 1
                            WHERE population_id = ?
                            AND task_id = ?
                            """
        command_get_task_info = """
                                SELECT intervals_trained, seed_for_shuffling
                                FROM populations
                                WHERE population_id = ?
                                AND task_id = ?
                                """
    else:
        conn = psycopg2.connect(connect_str_or_path)
        # FOR UPDATE takes an exclusive row lock, so two workers cannot both
        # hold the same candidate row (FOR SHARE allowed that and the two
        # subsequent UPDATEs then blocked each other). SKIP LOCKED makes a
        # concurrent worker pick a different idle row instead of waiting.
        command_get_task = """
                           SELECT task_id
                           FROM populations
                           WHERE population_id = %s
                           AND ready_for_exploitation = False
                           AND active = False
                           LIMIT 1
                           FOR UPDATE SKIP LOCKED
                           """
        command_lock_task = """
                            UPDATE populations
                            SET active = True
                            WHERE population_id = %s
                            AND task_id = %s
                            """
        command_get_task_info = """
                                SELECT intervals_trained, seed_for_shuffling
                                FROM populations
                                WHERE population_id = %s
                                AND task_id = %s
                                """
    # try/finally so errors do not leak the cursor/connection; the explicit
    # None check replaces a broad ``except TypeError`` that could also hide
    # unrelated TypeErrors raised while claiming the task.
    try:
        cur = conn.cursor()
        try:
            cur.execute(command_get_task, [population_id])
            row = cur.fetchone()
            if row is not None:
                task_id = row[0]
                cur.execute(command_lock_task, [population_id, task_id])
                conn.commit()
                cur.execute(command_get_task_info, [population_id, task_id])
                intervals_trained, seed_for_shuffling = cur.fetchone()
                return task_id, intervals_trained, seed_for_shuffling
        finally:
            cur.close()
    finally:
        conn.close()

    # No idle task was found: work out why and tell the caller.
    activities = get_col_from_populations(
        connect_str_or_path, use_sqlite, population_id, "active")
    if any(activities):
        raise RemainingTasksTaken
    intervals_trained_col = get_col_from_populations(
        connect_str_or_path, use_sqlite, population_id,
        "intervals_trained")
    if not any(i < interval_limit for i in intervals_trained_col):
        raise PopulationFinished
    readys = get_col_from_populations(
        connect_str_or_path, use_sqlite, population_id,
        "ready_for_exploitation")
    if all(readys):
        raise ExploitationNeeded
    raise ExploitationOcurring


def get_max_of_db_column(connect_str_or_path, use_sqlite, table_name,
                         column_name):
    """Return ``MAX(column_name)`` of ``table_name``.

    Args:
        connect_str_or_path: SQLite database path when ``use_sqlite`` is
            true, otherwise a psycopg2 connection string.
        use_sqlite: Select the SQLite backend instead of PostgreSQL.
        table_name: Table to query.
        column_name: Column whose maximum is requested.

    Returns:
        The maximum value, or ``None`` when the table has no rows.

    Note:
        Database errors (e.g. a missing table) propagate to the caller.
        With SQLite the names are interpolated unquoted into the statement,
        so they must come from trusted code.
    """
    if use_sqlite:
        conn = sqlite3.connect(connect_str_or_path)
        command = "SELECT MAX({}) FROM {}".format(column_name, table_name)
    else:
        conn = psycopg2.connect(connect_str_or_path)
        command = SQL("SELECT MAX({}) FROM {}").format(
            Identifier(column_name), Identifier(table_name))
    # try/finally: a missing table is an expected failure here, and the
    # connection must not leak when it happens.
    try:
        cur = conn.cursor()
        try:
            cur.execute(command)
            max_value = cur.fetchone()[0]
        finally:
            cur.close()
    finally:
        conn.close()
    return max_value


def insert_into_table(connect_str_or_path, use_sqlite, table_name,
                      key_value_pairs):
    """Insert one row into ``table_name``.

    Args:
        connect_str_or_path: SQLite database path when ``use_sqlite`` is
            true, otherwise a psycopg2 connection string.
        use_sqlite: Select the SQLite backend instead of PostgreSQL.
        table_name: Table receiving the row.
        key_value_pairs: Mapping of column name to value. Callables and
            classes are stored by their ``__name__``.

    Note:
        With SQLite the table and column names are interpolated unquoted
        into the statement, so they must come from trusted code.
    """
    fields = list(key_value_pairs.keys())
    values = [v.__name__ if callable(v) or isinstance(v, type) else v
              for v in key_value_pairs.values()]
    if use_sqlite:
        conn = sqlite3.connect(connect_str_or_path)
        # Use the requested table rather than a hard-coded "populations".
        command = "INSERT INTO {} ({}) VALUES ({})".format(
            table_name,
            ", ".join(fields),
            ", ".join("?" for _ in fields))
    else:
        register_numpy_types()
        # Open a single connection (a second connect() would leak the first).
        conn = psycopg2.connect(connect_str_or_path)
        command = SQL("INSERT INTO {} ({}) VALUES ({})").format(
            Identifier(table_name),
            SQL(", ").join([Identifier(field) for field in fields]),
            SQL(", ").join([SQL("%s") for _ in fields]))
    # try/finally so a failing statement does not leak the connection.
    try:
        cur = conn.cursor()
        try:
            cur.execute(command, values)
            conn.commit()
        finally:
            cur.close()
    finally:
        conn.close()


def create_table(connect_str_or_path, use_sqlite, command):
    """Execute a table-creating SQL ``command`` and commit it.

    Args:
        connect_str_or_path: SQLite database path when ``use_sqlite`` is
            true, otherwise a psycopg2 connection string.
        use_sqlite: Select the SQLite backend instead of PostgreSQL.
        command: Complete SQL statement to execute.

    Note:
        With SQLite, errors propagate to the caller. With PostgreSQL every
        error is caught; it is printed unless its message contains
        "already exists".
    """
    if use_sqlite:
        conn = sqlite3.connect(connect_str_or_path)
        # Close the connection even when the statement fails.
        try:
            cur = conn.cursor()
            cur.execute(command)
            conn.commit()
            cur.close()
        finally:
            conn.close()
    else:
        conn = None
        try:
            conn = psycopg2.connect(connect_str_or_path)
            cur = conn.cursor()
            cur.execute(command)
            conn.commit()
            cur.close()
        except (Exception, psycopg2.DatabaseError) as error:
            # Best effort: an existing table is expected and stays silent.
            if "already exists" not in str(error):
                print(error)
        finally:
            if conn is not None:
                conn.close()


def get_placeholders(num, form):
    """
    Return ``num`` copies of ``form`` separated by ", ".

    Example:
        >>> get_placeholders(num=3, form="%s")
        '%s, %s, %s'
    """
    return ", ".join(form for _ in range(num))


def create_new_population(connect_str_or_path, use_sqlite, population_size):
    """Insert a fresh population of idle tasks and return its id.

    The ``populations`` table is created when querying it fails. The new
    population gets the id following the largest existing one, or 0 when
    there is none.

    Args:
        connect_str_or_path: SQLite database path when ``use_sqlite`` is
            true, otherwise a psycopg2 connection string.
        use_sqlite: Select the SQLite backend instead of PostgreSQL.
        population_size: Number of task rows to insert (task ids
            ``0 .. population_size - 1``).

    Returns:
        The id of the new population.
    """
    # SQLite stores the two flags as INTEGER 0/1, PostgreSQL as BOOLEAN.
    flag_type = "INTEGER" if use_sqlite else "BOOLEAN"
    unset_flag = 0 if use_sqlite else False
    command = """
              CREATE TABLE populations (
                    population_id INTEGER,
                    task_id INTEGER,
                    intervals_trained INTEGER,
                    ready_for_exploitation {flag_type},
                    active {flag_type},
                    score REAL,
                    seed_for_shuffling INTEGER
              )
              """.format(flag_type=flag_type)
    table_name = "populations"
    try:
        latest_population_id = get_max_of_db_column(connect_str_or_path,
                                                    use_sqlite,
                                                    table_name,
                                                    "population_id")
    except (sqlite3.OperationalError, psycopg2.ProgrammingError):
        # The table does not exist yet.
        create_table(connect_str_or_path, use_sqlite, command)
        population_id = 0
    else:
        # MAX() is NULL/None for an existing but empty table; start at 0
        # instead of failing on ``None + 1``.
        if latest_population_id is None:
            population_id = 0
        else:
            population_id = latest_population_id + 1
    for task_id in range(population_size):
        key_value_pairs = dict(population_id=population_id,
                               task_id=task_id,
                               intervals_trained=0,
                               ready_for_exploitation=unset_flag,
                               active=unset_flag,
                               score=None,
                               seed_for_shuffling=123)
        insert_into_table(connect_str_or_path, use_sqlite, table_name,
                          key_value_pairs)
    return population_id


def choose(x):
    """Return one element of ``x`` drawn with ``numpy.random.choice``."""
    picked = np.random.choice(x)
    return picked


def print_separator():
    """Print a horizontal rule made of 80 dashes."""
    rule = "".ljust(80, "-")
    print(rule)


def get_database_path(here):
    """Return the path ``<here>/logs/database.sqlite``."""
    return osp.join(here, "logs", "database.sqlite")


def load_sqlite_table(database_path, table_name):
    """Load a whole SQLite table into a pandas DataFrame.

    Args:
        database_path: Path of the SQLite database file.
        table_name: Table to read (interpolated into the query unquoted).

    Returns:
        ``(table, connection)``: the DataFrame and the still-open
        connection; the caller is responsible for closing it.

    Raises:
        DatabaseError: If the table does not exist.
        Exception: If pandas reports any other database error.
    """
    conn = sqlite3.connect(database_path)
    try:
        df = pd.read_sql("SELECT * FROM %s" % table_name, conn)
    except DatabaseError as e:
        # Nothing is returned on failure, so the connection must be closed
        # here or it would leak.
        conn.close()
        if 'no such table' in str(e):
            print("\nNo such table: %s" % table_name)
            print("Create the table before loading it. " +
                  "Consider using the create_sqlite_table function")
            raise DatabaseError("no such table: %s" % table_name) from e
        print(e)
        raise Exception("Failed to load %s table. Unknown error." %
                        table_name) from e
    except Exception:
        conn.close()
        raise
    return df, conn


def create_sqlite_table(database_path, table_name, table_header):
    """Create an empty table with the given column names.

    Returns (table, connection). table is the empty pandas DataFrame that
    was written; connection is left open.
    """
    connection = sqlite3.connect(database_path)
    print("\nCreating %s table in SQLite3 database." % table_name)
    empty_table = pd.DataFrame(columns=table_header)
    empty_table.to_sql(table_name, connection, index=False)
    return empty_table, connection


def create_log(filepath, headers):
    """Create ``filepath`` holding one comma-joined header line.

    Does nothing when the file already exists.
    """
    if osp.exists(filepath):
        return
    with open(filepath, 'w') as log_file:
        header_line = ','.join(headers) + '\n'
        log_file.write(header_line)


def get_RAM():
    """Return the ``used`` field of ``psutil.virtual_memory()``."""
    memory = psutil.virtual_memory()
    return memory.used


def git_hash():
    """Return the abbreviated hash of the latest commit.

    Runs ``git log -n 1 --pretty="%h"`` and returns its stripped output
    exactly as ``subprocess.check_output`` delivers it.
    """
    argv = shlex.split('git log -n 1 --pretty="%h"')
    raw_output = subprocess.check_output(argv)
    return raw_output.strip()


def transform_portrait(img):
    """Convert an HxWx3 RGB image into a mean-subtracted CxHxW BGR array.

    The input is cast to uint8, its channels are reversed, it is converted
    to float64, a fixed per-channel BGR mean is subtracted, and the axes
    are reordered to channel-first.
    """
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
    pixels = np.array(img, dtype=np.uint8)
    bgr = pixels[:, :, ::-1].astype(np.float64)  # RGB -> BGR
    bgr -= mean_bgr
    return bgr.transpose(2, 0, 1)  # HxWxC --> CxHxW


def split_trn_val(num_train, valid_size=0.2, shuffle=False):
    """Split ``range(num_train)`` into training and validation indices.

    The first ``floor(valid_size * num_train)`` indices (after optional
    shuffling with ``np.random.shuffle``) become the validation set; the
    rest become the training set.

    Returns:
        ``(trn_indices, val_indices)`` as two lists.
    """
    order = list(range(num_train))
    if shuffle:
        np.random.shuffle(order)
    n_val = int(np.floor(valid_size * num_train))
    return order[n_val:], order[:n_val]


def cross_entropy2d(score, target, weight=None, size_average=True):
    """Pixel-wise cross entropy that ignores negative ("unknown") labels.

    Args:
        score: 4D tensor (N x C x H x W) of per-class scores.
        target: Integer label tensor with N*H*W entries; entries below zero
            are excluded from the loss.
        weight: Optional per-class weights forwarded to ``F.nll_loss``.
        size_average: If true, divide the summed loss by the number of
            non-negative labels.

    Returns:
        Scalar loss tensor.
    """
    # TODO: Replace this entire function with nn.functional.cross_entropy
    #   with ignore_index set to -1.
    n, c, h, w = score.size()
    # Normalize over the class axis explicitly; this is the axis the
    # deprecated implicit-dim call selects for a 4D input.
    log_p = F.log_softmax(score, dim=1)

    # Flatten to one row of class log-probabilities per pixel.
    log_p = log_p.permute(0, 2, 3, 1).contiguous().view(-1, c)
    flat_target = target.reshape(n * h * w)

    # Drop the pixels whose label is "unknown" (less than zero).
    valid = flat_target >= 0
    log_p = log_p[valid]
    flat_target = flat_target[valid]

    # reduction='sum' replaces the deprecated size_average=False.
    loss = F.nll_loss(log_p, flat_target, weight=weight, reduction='sum')

    if size_average:
        loss = loss / valid.sum()
    return loss


def scoretensor2mask(scoretensor):
    """
    - scoretensor (3D torch tensor) (CxHxW): Each channel contains the scores
        for the corresponding category in the image.
    Returns a numpy uint8 array: the index of the highest-scoring channel
    per pixel, multiplied by 255.
    """
    winners = scoretensor.max(0)[1]  # indices of the per-pixel maximum
    return winners.numpy().astype(np.uint8) * 255


def detransform_portrait(img, mean="voc"):
    """Undo the mean subtraction and channel layout of a portrait.

    Args:
        img: CxHxW BGR image as a numpy array (anything ``np.asarray`` can
            convert is accepted). The input is not modified.
        mean: Name of the mean to add back; only ``"voc"`` is known.

    Returns:
        HxWxC RGB numpy array of dtype uint8.

    Raises:
        ValueError: If ``mean`` is not ``"voc"``.
    """
    if mean == "voc":
        mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
    else:
        raise ValueError("unknown mean")
    img = np.asarray(img, dtype=np.float64)
    img = img.transpose((1, 2, 0))  # CxHxW --> HxWxC
    # Out-of-place addition: an in-place ``+=`` on the transposed view would
    # write through to the caller's array (and fail for integer input).
    img = img + mean_bgr
    img = img[:, :, ::-1]  # BGR -> RGB
    return img.astype(np.uint8)


def detransform_mask(mask):
    """Return ``mask`` cast to uint8 and multiplied by 255.

    ``astype`` copies, so the input array is left untouched.
    """
    return mask.astype(np.uint8) * 255


def mask_image(img, mask, opacity=1.00, bg=False):
    """
        - img (PIL)
        - mask (PIL)
        - opacity (float) (default: 1.00)
        - bg (bool): if true, the blank image and img swap places in the
          composite.
    Composites img with an all-zero RGB image through mask and, when
    opacity is below 1, blends the result with img.
    Returns a PIL image.
    """
    black = Image.new('RGB', img.size, color=0)
    first, second = (black, img) if bg else (img, black)
    result = Image.composite(first, second, mask)
    if opacity < 1:
        result = Image.blend(img, result, opacity)
    return result


def show_portrait_pred_mask(portrait, preds, mask, start_iteration,
                            evaluation_interval,
                            opacity=None, bg=False, fig=None):
    """
    Show the input portrait, each prediction, and the target mask in a grid.

    Args:
        - portrait: image array accepted by ``Image.fromarray``
        - preds (list of np.ndarray): list of mask predictions
        - mask: target mask; converted with ``Image.fromarray`` when
          ``opacity`` is set
        - start_iteration, evaluation_interval: prediction ``i`` is titled
          "iter. <start_iteration + i * evaluation_interval>"
        - opacity: when truthy, predictions and target are overlaid on the
          portrait via ``mask_image`` with this opacity
        - bg: forwarded to ``mask_image``
        - fig: forwarded to ``plots``
    A visualization function.
    Returns nothing.
    """
    portrait_pil = Image.fromarray(portrait)

    # The first panel is always the raw input.
    images = [portrait]
    titles = ["input"]
    cmaps = [None]

    # One panel per prediction.
    for step, pred in enumerate(preds):
        panel = Image.fromarray(pred)
        if opacity:
            panel = mask_image(portrait_pil, panel, opacity, bg)
        images.append(panel)
        titles.append("iter. %d" % (start_iteration +
                                    step * evaluation_interval))
        cmaps.append("gray")

    # The last panel is the target mask.
    target = mask
    if opacity:
        target = mask_image(portrait_pil, Image.fromarray(mask), opacity, bg)
    images.append(target)
    titles.append("target")
    cmaps.append("gray")

    # Grid layout: five columns and as many rows as needed.
    n_cols = 5
    n_rows = int(np.ceil(len(images) / n_cols))
    width = 12
    height = n_rows * (width / n_cols + 1)
    plots(images, titles=titles, cmap=cmaps, rows=n_rows, cols=n_cols,
          figsize=(width, height), fig=fig)


def set_seed(seed):
    """Seed Python's ``random``, numpy, and torch (plus CUDA if available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def get_fnames(d, random=False):
    """Return the paths of the regular files directly inside directory ``d``.

    Args:
        d: Directory to list, with or without a trailing path separator.
        random: If true, shuffle the result in place with ``random.shuffle``.

    Returns:
        List of paths, each formed by joining ``d`` and the file name.
    """
    # Join with osp.join so the returned paths are valid even when ``d`` has
    # no trailing separator (plain ``d + f`` produced "dirfile").
    candidates = (osp.join(d, name) for name in listdir(d))
    fnames = [path for path in candidates if osp.isfile(path)]
    print("Number of files found in %s: %s" % (d, len(fnames)))
    if random:
        shuffle(fnames)
    return fnames


def rm_dir_and_ext(filepath):
    """Return the file name of ``filepath`` without directory and extension.

    Only the last extension is removed, and a name without any extension is
    returned unchanged.

    Example:
        >>> rm_dir_and_ext('../data/portraits/00074.jpg')
        '00074'
    """
    return osp.splitext(osp.basename(filepath))[0]


def get_flickr_id(portrait_fname):
    """
    Parse the numeric id out of a portrait file name.

    Input (string): '../data/portraits/flickr/cropped/portraits/00074.jpg'
    Output (int): 74
    """
    stem = rm_dir_and_ext(portrait_fname)
    return int(stem)


def get_lines(fname):
    '''Read a text file and return, per line, its whitespace-split tokens.'''
    with open(fname) as f:
        return [line.strip().split() for line in f]


def hist(data, figsize=(6, 3)):
    """Show a histogram of ``data`` in a new figure of size ``figsize``."""
    plt.figure(figsize=figsize)
    # The new figure is pyplot's current one, so its current axes receive
    # the histogram.
    axes = plt.gca()
    axes.hist(data)
    plt.show()


def plot_portraits_and_masks(portraits, masks):
    """Show four portraits (top row) above their four masks (bottom row).

    ``portraits`` and ``masks`` must have equal length; only the first four
    entries of each are drawn, in a fixed 2x4 grid. Masks are converted
    with ``gray2rgb`` before display.
    """
    assert len(portraits) == len(masks)
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    fig.tight_layout()
    panels = list(axes.flat)
    for col, ax in enumerate(panels[:4]):
        ax.imshow(portraits[col], interpolation="spline16")
    for col, ax in enumerate(panels[4:]):
        ax.imshow(gray2rgb(masks[col]))
    for ax in panels:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


def gray2rgb(gray):
    """Return a uint8 array of shape (rows, cols, 3) whose three channels
    each hold a copy of the 2D array ``gray``."""
    n_rows, n_cols = gray.shape
    rgb = np.empty((n_rows, n_cols, 3), dtype=np.uint8)
    # Broadcast the single channel across the last axis.
    rgb[...] = gray[:, :, np.newaxis]
    return rgb


def plots(imgs, figsize=(12, 12), rows=None, cols=None,
          interp=None, titles=None, cmap='gray',
          fig=None):
    """Draw images in a grid of subplots on a figure.

    Args:
        imgs: One image or a list of images; each is converted with
            ``np.array``.
        figsize: ``(width, height)``; only used when ``fig`` is given, and
            then only its height. Without ``fig`` the size is derived from
            the grid.
        rows, cols: Grid shape. If both are missing, all images go in one
            row; if one is missing, it is set equal to the other. Without
            ``fig`` the row count is recomputed from ``cols``.
        interp: Interpolation name, or a list with one entry per image.
        titles: Optional list of subplot titles.
        cmap: Colormap name, or a list with one entry per image.
        fig: Existing figure to clear and draw on; a new one is created
            when omitted.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [np.array(img) for img in imgs]
    n = len(imgs)
    if not isinstance(cmap, list):
        if imgs and imgs[0].ndim == 2:
            cmap = 'gray'
        cmap = [cmap] * n
    if not isinstance(interp, list):
        interp = [interp] * n
    if not rows and not cols:
        # At least one column, so an empty image list yields an empty
        # figure instead of a division by zero below.
        cols = max(n, 1)
        rows = 1
    elif not rows:
        rows = cols
    elif not cols:
        cols = rows
    if fig is None:
        rows = int(np.ceil(n / cols))
        w = 12
        h = rows * (w / cols + 1)
        figsize = (w, h)
        fig = plt.figure(figsize=figsize)
    fontsize = 13 if cols == 5 else 16
    fig.set_figheight(figsize[1], forward=True)
    fig.clear()
    for i, img in enumerate(imgs):
        ax = fig.add_subplot(rows, cols, i + 1)
        if titles:
            ax.set_title(titles[i], fontsize=fontsize)
        # Draw through the axes/figure objects rather than pyplot's
        # "current" state, so a caller-supplied ``fig`` is the one drawn
        # on even when it is not pyplot's current figure.
        ax.imshow(img, interpolation=interp[i], cmap=cmap[i])
        ax.axis('off')
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1,
                        wspace=.1, hspace=0)
    fig.canvas.draw()