package com.shzlw.poli.service;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.shzlw.poli.config.AppProperties;
import com.shzlw.poli.dto.Column;
import com.shzlw.poli.dto.FilterParameter;
import com.shzlw.poli.dto.QueryResult;
import com.shzlw.poli.dto.Table;
import com.shzlw.poli.util.CommonUtils;
import com.shzlw.poli.util.Constants;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import javax.annotation.PostConstruct;
import javax.sql.DataSource;
import java.io.IOException;
import java.lang.reflect.Field;
import java.sql.*;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.*;

@Service
public class JdbcQueryService {

    private static final Logger LOGGER = LoggerFactory.getLogger(JdbcQueryService.class);

    private static Map<Integer, String> JDBC_TYPE_MAP = new HashMap<>();

    @Autowired
    ObjectMapper mapper;

    @Autowired
    AppProperties appProperties;

    @PostConstruct
    public void init() throws IllegalAccessException {
        for (Field field : java.sql.Types.class.getFields()) {
            JDBC_TYPE_MAP.put((Integer) field.get(null), field.getName());
        }
    }

    public String ping(DataSource dataSource, String sql) {
        try (Connection con = dataSource.getConnection();
             PreparedStatement ps = con.prepareStatement(sql);
             ResultSet rs = ps.executeQuery();) {
            while (rs.next()) {
                break;
            }
            return Constants.SUCCESS;
        } catch (Exception e) {
            return CommonUtils.getSimpleError(e);
        }
    }

    public List<Table> getSchema(DataSource dataSource) {
        List<Table> tables = new ArrayList<>();
        try (Connection conn = dataSource.getConnection()) {
            DatabaseMetaData metaData = conn.getMetaData();
            ResultSet rs = metaData.getTables(null, null, null, null);
            while (rs.next()) {
                String name = rs.getString("TABLE_NAME");
                String type = rs.getString("TABLE_TYPE");
                tables.add(new Table(name, type));
            }

            for (Table t : tables) {
                rs = metaData.getColumns(null, null, t.getName(), null);
                List<Column> columns = new ArrayList<>();

                while (rs.next()) {
                    String columnName = rs.getString("COLUMN_NAME");
                    int javaType = rs.getInt("DATA_TYPE");
                    String dbType = rs.getString("TYPE_NAME");
                    int length = rs.getInt("COLUMN_SIZE");
                    columns.add(new Column(columnName, JDBC_TYPE_MAP.get(javaType), dbType, length));
                }
                t.setColumns(columns);
            }
            return tables;
        } catch (Exception e) {
            return Collections.emptyList();
        }
    }

    public QueryResult queryByParams(
            DataSource dataSource,
            String sql,
            List<FilterParameter> filterParams,
            int resultLimit
    ) {
        if (dataSource == null) {
            return QueryResult.ofError(Constants.ERROR_NO_DATA_SOURCE_FOUND);
        } else if (StringUtils.isEmpty(sql)) {
            return QueryResult.ofError(Constants.ERROR_EMPTY_SQL_QUERY);
        }

        NamedParameterJdbcTemplate npjt = new NamedParameterJdbcTemplate(dataSource);
        Map<String, Object> namedParameters = getNamedParameters(filterParams);

        // Handle multiple SQL statements.
        // If there are multiple sql statements, only return query results from the last query.
        List<String> sqls = JdbcQueryServiceHelper.getQueryStatements(sql);
        int preQueryNumber = sqls.size() - 1;
        if (appProperties.getAllowMultipleQueryStatements()) {
            for (int i = 0; i < preQueryNumber; i++) {
                String parsedSql = JdbcQueryServiceHelper.parseSqlStatementWithParams(sqls.get(i), namedParameters);
                npjt.execute(parsedSql, (ps) -> ps.execute());
            }
        }

        String parsedSql = JdbcQueryServiceHelper.parseSqlStatementWithParams(sqls.get(preQueryNumber), namedParameters);
        return executeQuery(npjt, parsedSql, namedParameters, resultLimit);
    }

    public QueryResult executeQuery(DataSource dataSource, String sql, String contentType) {
        JdbcTemplate jt = new JdbcTemplate(dataSource);
        final int maxQueryResult = JdbcQueryServiceHelper.calculateMaxQueryResultLimit(appProperties.getMaximumQueryRecords(), Constants.QUERY_RESULT_NOLIMIT);

        QueryResult result = jt.query(sql, new Object[] {}, new ResultSetExtractor<QueryResult>() {
            @Nullable
            @Override
            public QueryResult extractData(ResultSet rs) {
                try {
                    ResultSetMetaData metadata = rs.getMetaData();
                    String[] columnNames = getColumnNames(metadata);
                    List<Column> columns = getColumnList(metadata);
                    String data;
                    if (Constants.CONTENT_TYPE_CSV.equals(contentType)) {
                        data = resultSetToCsvString(rs, columnNames, maxQueryResult);
                    } else {
                        data = resultSetToJsonString(rs, metadata, maxQueryResult);
                    }
                    return QueryResult.ofData(data, columns);
                } catch (Exception e) {
                    String error = CommonUtils.getSimpleError(e);
                    return QueryResult.ofError(error);
                }
            }
        });

        return result;
    }

    private QueryResult executeQuery(NamedParameterJdbcTemplate npjt,
                                     String sql,
                                     Map<String, Object> namedParameters,
                                     int resultLimit) {
        // Determine max query result
        final int maxQueryResult = JdbcQueryServiceHelper.calculateMaxQueryResultLimit(appProperties.getMaximumQueryRecords(), resultLimit);

        QueryResult result = npjt.query(sql, namedParameters, new ResultSetExtractor<QueryResult>() {
            @Nullable
            @Override
            public QueryResult extractData(ResultSet rs) {
                try {
                    ResultSetMetaData metadata = rs.getMetaData();
                    String[] columnNames = getColumnNames(metadata);
                    List<Column> columns = getColumnList(metadata);
                    String data = resultSetToJsonString(rs, metadata, maxQueryResult);
                    return QueryResult.ofData(data, columns);
                } catch (Exception e) {
                    String error = CommonUtils.getSimpleError(e);
                    return QueryResult.ofError(error);
                }
            }
        });

        return result;
    }

    public Map<String, Object> getNamedParameters(final List<FilterParameter> filterParams) {
        Map<String, Object> namedParameters = new HashMap<>();
        if (filterParams == null || filterParams.isEmpty()) {
            return namedParameters;
        }

        for (FilterParameter param : filterParams) {
            if (!JdbcQueryServiceHelper.isFilterParameterEmpty(param)) {
                String type = param.getType();
                String name = param.getParam();
                String value = param.getValue();

                if (type.equals(Constants.FILTER_TYPE_USER_ATTRIBUTE)) {
                    namedParameters.put(name, value);
                } else if (type.equals(Constants.FILTER_TYPE_SLICER)) {
                    String remark = param.getRemark();
                    if (remark == null) {
                        try {
                            List<String> array = Arrays.asList(mapper.readValue(value, String[].class));
                            if (!array.isEmpty()) {
                                namedParameters.put(name, array);
                            }
                        } catch (IOException e) {
                            LOGGER.warn("exception: {}", e);
                        }
                    }
                } else if (type.equals(Constants.FILTER_TYPE_SINGLE)) {
                    try {
                        String singleValue = mapper.readValue(value, String.class);
                        namedParameters.put(name, singleValue);
                    } catch (IOException e) {
                        LOGGER.warn("exception: {}", e);
                    }
                } else if (type.equals(Constants.FILTER_TYPE_DATE_PICKER)) {
                    try {
                        String dateStr = mapper.readValue(value, String.class);
                        if (!StringUtils.isEmpty(dateStr)) {
                            Date date = new SimpleDateFormat("yyyy-MM-dd").parse(dateStr);
                            namedParameters.put(name, date);
                        }
                    } catch (IOException | ParseException e) {
                        LOGGER.warn("exception: {}", e);
                    }
                }
                else {
                    throw new IllegalArgumentException("Unknown filter type");
                }
            }
        }
        return namedParameters;
    }

    private String[] getColumnNames(ResultSetMetaData metadata) throws SQLException {
        int columnCount = metadata.getColumnCount();
        String[] columnNames = new String[columnCount + 1];
        for (int i = 1; i <= columnCount; i++) {
            // Use column label to fetch the column alias instead of using column name.
            // If there is no alias, column label is the same as column name.
            String columnLabel = metadata.getColumnLabel(i);
            columnNames[i] = columnLabel;
        }
        return columnNames;
    }

    private List<Column> getColumnList(ResultSetMetaData metadata) throws SQLException {
        int columnCount = metadata.getColumnCount();
        List<Column> columns = new ArrayList<>();
        for (int i = 1; i <= columnCount; i++) {
            int columnType = metadata.getColumnType(i);;
            String dbType = metadata.getColumnTypeName(i);
            int length = metadata.getColumnDisplaySize(i);
            // Use column label to fetch the column alias instead of using column name.
            // If there is no alias, column label is the same as column name.
            String columnLabel = metadata.getColumnLabel(i);
            columns.add(new Column(columnLabel, JDBC_TYPE_MAP.get(columnType), dbType, length));
        }
        return columns;
    }

    private String resultSetToJsonString(ResultSet rs, ResultSetMetaData metadata, int maxQueryResult) throws SQLException {
        int columnCount = metadata.getColumnCount();
        ObjectMapper mapper = new ObjectMapper();
        ArrayNode array = mapper.createArrayNode();
        int rowCount = 0;
        while (rs.next()) {
            ObjectNode node = mapper.createObjectNode();
            for (int i = 1; i <= columnCount; i++) {
                String columnLabel = metadata.getColumnLabel(i);
                int columnType = metadata.getColumnType(i);
                switch (columnType) {
                    case java.sql.Types.VARCHAR:
                    case java.sql.Types.CHAR:
                    case java.sql.Types.LONGVARCHAR:
                        node.put(columnLabel, rs.getString(i));
                        break;
                    case java.sql.Types.TINYINT:
                    case java.sql.Types.SMALLINT:
                    case java.sql.Types.INTEGER:
                        node.put(columnLabel, rs.getInt(i));
                        break;
                    case java.sql.Types.NUMERIC:
                    case java.sql.Types.DECIMAL:
                        node.put(columnLabel, rs.getBigDecimal(i));
                        break;
                    case java.sql.Types.DOUBLE:
                    case java.sql.Types.FLOAT:
                    case java.sql.Types.REAL:
                        node.put(columnLabel, rs.getDouble(i));
                        break;
                    case java.sql.Types.BOOLEAN:
                    case java.sql.Types.BIT:
                        node.put(columnLabel, rs.getBoolean(i));
                        break;
                    case java.sql.Types.BIGINT:
                        node.put(columnLabel, rs.getLong(i));
                        break;
                    case java.sql.Types.NVARCHAR:
                    case java.sql.Types.NCHAR:
                        node.put(columnLabel, rs.getNString(i));
                        break;
                    default:
                        // Unhandled types
                        node.put(columnLabel, rs.getString(i));
                        break;
                }
            }
            array.add(node);
            rowCount++;
            if (maxQueryResult != Constants.QUERY_RESULT_NOLIMIT && rowCount >= maxQueryResult) {
                break;
            }
        }
        return array.toString();
    }


    private String resultSetToCsvString(ResultSet rs, String[] columnNames, int maxQueryResult) throws SQLException {
        int columnCount = columnNames.length - 1;
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i <= columnCount; i++) {
            sb.append(columnNames[i]).append(",");
        }
        sb.append("\r\n");

        int rowCount = 0;
        while (rs.next()) {
            for (int i = 1; i <= columnCount; i++) {
                // TODO: handle quotation marks
                sb.append(rs.getString(i)).append(",");
            }
            sb.append("\r\n");
            rowCount++;
            if (maxQueryResult != Constants.QUERY_RESULT_NOLIMIT && rowCount >= maxQueryResult) {
                break;
            }
        }
        return sb.toString();
    }
}