from functools import lru_cache

from sqlparse import parse, tokens
from sqlparse.sql import Comment, IdentifierList, Parenthesis, Token


@lru_cache(maxsize=500)
def sql_fingerprint(query, hide_columns=True):
    """
    Reduce a SQL query to a value-independent fingerprint.

    The query is parsed, its whitespace is normalised and literal values
    (and, when ``hide_columns`` is true, column lists) are masked, so that
    queries differing only in those details produce the same string.

    Only the first parsed statement is considered. An empty string is
    returned when parsing yields no statements. Results are memoised per
    ``(query, hide_columns)`` pair.
    """
    statements = parse(query)

    if statements:
        statement = sql_recursively_strip(statements[0])
        sql_recursively_simplify(statement, hide_columns=hide_columns)
        return str(statement).strip()

    return ""


# Token types treated as literal values (numbers and strings). Tokens whose
# ``ttype`` is one of these are masked when a query is simplified.
sql_deleteable_tokens = (
    tokens.Number,
    tokens.Number.Float,
    tokens.Number.Integer,
    tokens.Number.Hexadecimal,
    tokens.String,
    tokens.String.Single,
)


def sql_trim(node, idx=0):
    """
    Remove consecutive whitespace tokens found at position ``idx``.

    Tokens are deleted from ``node.tokens`` in place for as long as the
    token at ``idx`` is whitespace and the list holds more than
    ``abs(idx)`` tokens. A negative ``idx`` trims relative to the end.
    """
    children = node.tokens
    floor = abs(idx)

    # The length is re-read each pass because every deletion shrinks the list.
    while len(children) > floor and children[idx].is_whitespace:
        del children[idx]


def sql_strip(node):
    """
    Collapse each run of whitespace tokens in ``node`` to a single space.

    The first whitespace token of a run has its value set to ``" "``; every
    following whitespace token in the same run is blanked to ``""``. Tokens
    are modified in place and none are removed.
    """
    previous_was_space = False

    for child in node.tokens:
        if not child.is_whitespace:
            previous_was_space = False
            continue

        child.value = "" if previous_was_space else " "
        previous_was_space = True


def sql_recursively_strip(node):
    """
    Normalise whitespace throughout ``node`` and all of its sub-lists.

    Sub-lists are processed first, then the node itself. Comment nodes are
    left untouched. For parenthesised groups, whitespace directly after the
    opening bracket and directly before the closing bracket is removed.

    Returns the same ``node`` that was passed in.
    """
    for child in node.get_sublists():
        sql_recursively_strip(child)

    if not isinstance(node, Comment):
        sql_strip(node)

        if isinstance(node, Parenthesis):
            # Index 1 follows the opening bracket; index -2 precedes the
            # closing one.
            sql_trim(node, 1)
            sql_trim(node, -2)

    return node


def sql_recursively_simplify(node, hide_columns=True):
    """
    Mask the value-dependent parts of the token tree rooted at ``node``.

    The tree is modified in place; nothing is returned. The following
    rewrites are applied:

    * ``UPDATE`` statements: everything between ``SET`` and the ``WHERE``
      clause (or the end of the statement when there is no ``WHERE``) is
      replaced by ``" ... "``.
    * ``WHERE`` clauses: an ``IN (...)`` whose parenthesis holds only
      literal tokens has its contents replaced by ``"..."``.
    * Savepoint statements: the savepoint name is replaced by ``"`#`"``.
    * A leading ``"_django_curs_...`` cursor name is replaced by
      ``"_django_curs_#"``.
    * Literal tokens and ``NULL`` become ``"#"``.
    * When ``hide_columns`` is true, identifier lists are replaced by
      ``"..."`` unless they directly follow ``ORDER BY``, ``GROUP BY`` or
      ``HAVING``.

    Nested groups are simplified recursively with the same ``hide_columns``.
    """
    # Nothing to inspect; also keeps the ``tokens[0]`` lookups below safe.
    if not node.tokens:
        return

    # Erase which fields are being updated in an UPDATE
    if node.tokens[0].value == "UPDATE":
        i_set = next(
            (i for (i, t) in enumerate(node.tokens) if t.value == "SET"), None
        )
        if i_set is not None:
            i_where = next(
                (
                    i
                    for (i, t) in enumerate(node.tokens)
                    if t.is_group and t.tokens[0].value == "WHERE"
                ),
                None,
            )
            # An UPDATE without a WHERE clause has nothing to keep after the
            # assignments, so the statement simply ends at the placeholder.
            tail = [] if i_where is None else node.tokens[i_where:]
            middle = [Token(tokens.Punctuation, " ... ")]
            node.tokens = node.tokens[: i_set + 1] + middle + tail

    # Ensure IN clauses with simple value in always simplify to "..."
    if node.tokens[0].value == "WHERE":
        in_token_indices = [
            i for i, t in enumerate(node.tokens) if t.value == "IN"
        ]
        for in_token_index in in_token_indices:
            parenthesis = next(
                (
                    t
                    for t in node.tokens[in_token_index + 1 :]
                    if isinstance(t, Parenthesis)
                ),
                None,
            )
            # IN not followed by a parenthesised group: nothing to collapse.
            if parenthesis is None:
                continue
            if all(
                getattr(t, "ttype", "") in sql_deleteable_tokens
                for t in parenthesis.tokens[1:-1]
            ):
                parenthesis.tokens[1:-1] = [Token(tokens.Punctuation, "...")]

    # Erase the names of savepoints since they are non-deterministic
    # SAVEPOINT x
    if str(node.tokens[0]) == "SAVEPOINT":
        node.tokens[2].tokens[0].value = "`#`"
        return
    # RELEASE SAVEPOINT x
    if len(node.tokens) >= 3 and node.tokens[2].value == "SAVEPOINT":
        node.tokens[4].tokens[0].value = "`#`"
        return
    # ROLLBACK TO SAVEPOINT X
    token_values = [getattr(t, "value", "") for t in node.tokens]
    if len(node.tokens) == 7 and token_values[:6] == [
        "ROLLBACK",
        " ",
        "TO",
        " ",
        "SAVEPOINT",
        " ",
    ]:
        node.tokens[6].tokens[0].value = "`#`"
        return

    # Erase volatile part of PG cursor name
    if node.tokens[0].value.startswith('"_django_curs_'):
        node.tokens[0].value = '"_django_curs_#"'

    # Keywords after which an identifier list is kept rather than hidden.
    clause_keywords = ["ORDER BY", "GROUP BY", "HAVING"]
    prev_word_token = None

    for token in node.tokens:
        ttype = getattr(token, "ttype", None)

        # Detect IdentifierList tokens within an ORDER BY, GROUP BY or HAVING
        # clauses
        inside_order_group_having = match_keyword(prev_word_token, clause_keywords)
        replace_columns = not inside_order_group_having and hide_columns

        if isinstance(token, IdentifierList) and replace_columns:
            token.tokens = [Token(tokens.Punctuation, "...")]
        elif hasattr(token, "tokens"):
            sql_recursively_simplify(token, hide_columns=hide_columns)
        elif ttype in sql_deleteable_tokens:
            token.value = "#"
        elif getattr(token, "value", None) == "NULL":
            token.value = "#"

        # Remember the last non-whitespace token so the next iteration can
        # tell which keyword it follows.
        if not token.is_whitespace:
            prev_word_token = token


def match_keyword(token, keywords):
    """
    Tell whether ``token`` is a keyword token matching one of ``keywords``.

    The token's value is upper-cased before the lookup, so ``keywords`` is
    expected to hold upper-case entries. A missing (falsy) token or a
    non-keyword token never matches.
    """
    if token and token.is_keyword:
        return token.value.upper() in keywords

    return False