package com.dfire.platform.alchemy.service.util;

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.config.Lex;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlInsert;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;

import com.google.common.collect.Lists;

/**
 * @author congbai
 * @date 2019/6/3
 */
public class SqlParseUtil {

    private static final SqlParser.Config CONFIG = SqlParser.configBuilder().setLex(Lex.MYSQL).build();

    public static void parse(List<String> sqls, List<String> sources, List<String> udfs, List<String> sinks)
        throws SqlParseException {
        for (String sql : sqls) {
            SqlParser sqlParser = SqlParser.create(sql, CONFIG);
            SqlNode sqlNode = sqlParser.parseStmt();
            SqlKind kind = sqlNode.getKind();
            switch (kind){
                case INSERT:
                    SqlInsert sqlInsert = (SqlInsert)sqlNode;
                    addSink(sinks, findSinkName(sqlInsert));
                    SqlSelect source = (SqlSelect) sqlInsert.getSource();
                    parseSource(source, sources, udfs);
                    break;
                case SELECT:
                    parseSource((SqlSelect) sqlNode, sources, udfs);
                    break;
                default:
                    throw new IllegalArgumentException("It must be an insert SQL, sql:" + sql);
            }
        }
    }

    public static List<String> findQuerySql(List<String> sqls)
        throws SqlParseException {
        List<String> newSqls = new ArrayList<>(sqls.size());
        for (String sql : sqls) {
            SqlParser sqlParser = SqlParser.create(sql, CONFIG);
            SqlNode sqlNode = sqlParser.parseStmt();
            if (sqlNode.getKind() != SqlKind.INSERT) {
                throw new IllegalArgumentException("It must be an insert SQL, sql:" + sql);
            }
            SqlInsert sqlInsert = (SqlInsert)sqlNode;
            newSqls.add(sqlInsert.getSource().toString());
        }
        return newSqls;
    }

    private static void addSink(List<String> newSinks, String sinkName) {
        if (!newSinks.contains(sinkName)) {
            newSinks.add(sinkName);
        }
    }

    private static void addUdf(List<String> newUdfs, String udfName) {
        if (!newUdfs.contains(udfName)) {
            newUdfs.add(udfName);
        }
    }

    private static void addSource(List<String> newSources, String sourceName) {
        if (!newSources.contains(sourceName)) {
            newSources.add(sourceName);
        }
    }

    private static void parseSource(SqlSelect sqlSelect, List<String> sources, List<String> udfs)
        throws SqlParseException {
        SqlNodeList selectList = sqlSelect.getSelectList();
        SqlNode from = sqlSelect.getFrom();
        SqlNode where = sqlSelect.getWhere();
        SqlNode having = sqlSelect.getHaving();
        parseSelectList(selectList, sources, udfs);
        parseFrom(from, sources, udfs);
        parseFunction(where, udfs);
        parseFunction(having, udfs);
    }

    /**
     * 解析select 字段中的函数
     *
     * @param sqlNodeList
     * @param udfs
     */
    private static void parseSelectList(SqlNodeList sqlNodeList, List<String> sources, List<String> udfs)
        throws SqlParseException {
        for (SqlNode sqlNode : sqlNodeList) {
            parseSelect(sqlNode, sources, udfs);
        }
    }

    private static void parseFrom(SqlNode from, List<String> sources, List<String> udfs) throws SqlParseException {
        SqlKind sqlKind = from.getKind();
        switch (sqlKind) {
            case IDENTIFIER:
                SqlIdentifier identifier = (SqlIdentifier)from;
                addSource(sources, identifier.getSimple());
                break;
            case AS:
                SqlBasicCall sqlBasicCall = (SqlBasicCall)from;
                parseFrom(sqlBasicCall.operand(0), sources, udfs);
                break;
            case SELECT:
                parseSource((SqlSelect)from, sources, udfs);
                break;
            case JOIN:
                SqlJoin sqlJoin = (SqlJoin)from;
                SqlNode left = sqlJoin.getLeft();
                SqlNode right = sqlJoin.getRight();
                parseFrom(left, sources, udfs);
                parseFrom(right, sources, udfs);
                break;
            case LATERAL:
                SqlBasicCall basicCall = (SqlBasicCall)from;
                SqlNode childNode = basicCall.getOperands()[0];
                parseFunction(childNode, udfs);
            default:
        }
    }

    private static void parseFunction(SqlNode sqlNode, List<String> udfs) {
        if (sqlNode instanceof SqlBasicCall) {
            SqlBasicCall sqlBasicCall = (SqlBasicCall)sqlNode;
            SqlOperator operator = sqlBasicCall.getOperator();
            if (operator instanceof SqlFunction) {
                SqlFunction sqlFunction = (SqlFunction)operator;
                SqlFunctionCategory category = sqlFunction.getFunctionType();
                switch (category) {
                    case USER_DEFINED_FUNCTION:
                    case USER_DEFINED_SPECIFIC_FUNCTION:
                    case USER_DEFINED_TABLE_FUNCTION:
                    case USER_DEFINED_TABLE_SPECIFIC_FUNCTION:
                        addUdf(udfs, sqlFunction.getName());
                        break;
                    default:
                }
            } else {
                parseFunction(sqlBasicCall.operand(0), udfs);
            }
            // 查询嵌套的函数
            SqlNode[] nodes = sqlBasicCall.getOperands();
            if(nodes != null && nodes.length > 0){
                for(SqlNode node : nodes){
                    parseFunction(node, udfs);
                }
            }
        }
    }

    private static void parseSelect(SqlNode sqlNode, List<String> sources, List<String> udfs) throws SqlParseException {
        SqlKind sqlKind = sqlNode.getKind();
        switch (sqlKind) {
            case IDENTIFIER:
                break;
            case AS:
                SqlNode firstNode = ((SqlBasicCall)sqlNode).operand(0);
                parseSelect(firstNode, sources, udfs);
                break;
            case SELECT:
                parseSource((SqlSelect)sqlNode, sources, udfs);
                break;
            default:
                parseFunction(sqlNode, udfs);

        }
    }

    public static String parseSinkName(String sql) throws SqlParseException {
        SqlParser sqlParser = SqlParser.create(sql, CONFIG);
        SqlNode sqlNode = sqlParser.parseStmt();
        SqlKind sqlKind = sqlNode.getKind();
        if (sqlKind != SqlKind.INSERT) {
            throw new IllegalArgumentException("It must be an insert SQL, sql:" + sql);
        }
        return findSinkName((SqlInsert)sqlNode);
    }

    private static String findSinkName(SqlInsert sqlInsert) {
        SqlNode target = sqlInsert.getTargetTable();
        SqlKind targetKind = target.getKind();
        if (targetKind != SqlKind.IDENTIFIER) {
            throw new IllegalArgumentException("invalid insert SQL, sql:" + sqlInsert.toString());
        }
        SqlIdentifier identifier = (SqlIdentifier)target;
        return identifier.getSimple();
    }
}