package com.y.fish.base.api.repository;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.escape.ArrayBasedUnicodeEscaper;
import com.y.fish.base.api.model.Storage;
import com.y.fish.base.api.sql.Pagination;
import com.y.fish.base.api.sql.SqlQuery;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.PreparedStatementCreatorFactory;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;

import java.lang.reflect.Field;
import java.sql.*;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * Generic JDBC repository base class. Created by myliang on 11/7/17.
 *
 * <p>Subclasses describe their table through {@link #tableName()}, {@link #columnNames()},
 * {@link #columnTypes()} and {@link #convert(ResultSet)}. Reads go through
 * {@code readJdbcTemplate}, writes through {@code writeJdbcTemplate}.
 *
 * <p>For inserts and updates, the {@code created_at} / {@code updated_at} columns are filled in
 * by this class with the database's {@code now()}; callers should not set values for them.
 * Entity fields are read reflectively, using the camelCase form of each column name
 * (see {@link #changeColumnToFieldName(String)}).
 *
 * @param <T> entity type handled by the repository
 */
public abstract class BaseRepository<T> {

    Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Used only to render statement arguments in log messages. */
    static ObjectMapper jsonMapper = new ObjectMapper();

    @Autowired
    @Qualifier("writeJdbcTemplate")
    JdbcTemplate writeJdbcTemplate;

    @Autowired
    @Qualifier("readJdbcTemplate")
    JdbcTemplate readJdbcTemplate;

    /** @return name of the table this repository reads and writes */
    public abstract String tableName();

    /** @return column names considered by {@link #save(Object)} and {@link #update(Object, long)} */
    public abstract String[] columnNames();

    /** @return SQL types, index-aligned with {@link #columnNames()} */
    public abstract int[] columnTypes();

    /**
     * Maps the current row of {@code rs} to an entity.
     *
     * @param rs result set positioned on the row to convert
     * @return the entity built from that row
     * @throws SQLException if reading the row fails
     */
    public abstract T convert(ResultSet rs) throws SQLException;

    /** @return {@code select * from <table>} without a condition */
    public String selectSql() {
        return "select * from " + tableName() + "";
    }

    /** @return {@code select * from <table> where } ready to have a condition appended */
    public String selectWhereSql() {
        return "select * from " + tableName() + " where ";
    }

    /** @return primary key column name; {@code "id"} unless overridden */
    public String primaryKeyName() {
        return "id";
    }

    /** @return creation timestamp column name; {@code "created_at"} unless overridden */
    public String createdAtName() {
        return "created_at";
    }

    /** @return modification timestamp column name; {@code "updated_at"} unless overridden */
    public String updatedAtName() {
        return "updated_at";
    }

    /**
     * Loads one row by primary key.
     *
     * @param id primary key value
     * @return the entity, or {@code null} when no row matches
     */
    public T find(long id) {
        return find(primaryKeyName() + " = ?", new Object[]{id});
    }

    /**
     * Loads the single row matching a condition.
     *
     * @param sql  condition appended after {@code where}
     * @param args values bound to the condition's placeholders
     * @return the entity, or {@code null} when no row matches
     */
    public T find(String sql, Object[] args) {
        String fullSql = selectWhereSql() + sql;
        // Logging can no longer abort the lookup: argument rendering never throws.
        logSql("find.sql: {}, args: {}", fullSql, args);
        try {
            return readJdbcTemplate.queryForObject(fullSql, args, new BaseRowMapper());
        } catch (EmptyResultDataAccessException noRow) {
            // An empty result is an expected outcome here and is reported as null.
            return null;
        }
    }

    /**
     * Lists all rows matching a condition.
     *
     * @param sql  condition appended after {@code where}; {@code null} or blank selects every row
     * @param args values bound to the condition's placeholders
     * @return matching entities
     */
    public List<T> where(String sql, Object[] args) {
        String fullSql = buildSelect(sql);
        logSql("where.sql: {}, args: {}", fullSql, args);
        return readJdbcTemplate.query(fullSql, args, new BaseRowMapper());
    }

    /** @return every row of the table */
    public List<T> all() {
        return where("", new Object[]{});
    }

    /**
     * Counts the rows matching a condition.
     *
     * @param sql  condition appended after {@code where}; {@code null} or blank counts every row
     * @param args values bound to the condition's placeholders
     * @return number of matching rows
     */
    public long count(String sql, Object[] args) {
        StringBuilder countSql = new StringBuilder("select count(1) from ").append(tableName());
        if (!isBlank(sql)) {
            // Only add the where clause when there is a condition; a dangling "where" is invalid SQL.
            countSql.append(" where ").append(sql);
        }
        Long total = readJdbcTemplate.queryForObject(countSql.toString(), args, Long.class);
        return total == null ? 0L : total;
    }

    /**
     * Lists one page of the rows matching a condition.
     *
     * @param sql    condition appended after {@code where}; {@code null} or blank selects every row
     * @param args   values bound to the condition's placeholders
     * @param order  {@code order by} expression; {@code null} or blank omits the clause
     * @param limit  maximum number of rows to return
     * @param offset number of pages to skip; the SQL offset is {@code offset * limit}
     * @return entities of the requested page
     */
    public List<T> where(String sql, Object[] args, String order, int limit, int offset) {
        StringBuilder fullSql = new StringBuilder(buildSelect(sql));
        if (!isBlank(order)) {
            fullSql.append(" order by ").append(order);
        }
        // Multiply as long so large page numbers cannot overflow int.
        fullSql.append(" limit ").append(limit).append(" offset ").append((long) offset * limit);

        String pagedSql = fullSql.toString();
        logSql("list.where.sql: {}, args: {}", pagedSql, args);
        return readJdbcTemplate.query(pagedSql, args, new BaseRowMapper());
    }

    /**
     * Runs a paged query: counts all matching rows and loads the requested page.
     *
     * @param query condition, parameters, ordering and paging information
     * @return pagination holding the total row count and the page content
     */
    public Pagination<T> where(SqlQuery query) {
        query.toSql();
        Pagination<T> pagination = new Pagination<>();
        pagination.setTotal(count(query.getSql(), query.getParams()));
        pagination.setContent(where(query.getSql(), query.getParams(),
                query.getOrderBy(), query.pageRows(), query.pageOffset()));
        return pagination;
    }

    /**
     * Inserts an entity and writes the generated primary key back into it.
     *
     * <p>Only columns whose field value is non-null are inserted. The primary key and the
     * created/updated timestamp columns are skipped even if listed in {@link #columnNames()},
     * because {@link #insert(List, List, int[])} always supplies them itself.
     *
     * @param t entity to insert
     * @return the same entity, with its primary key field set
     * @throws Exception if a mapped field cannot be found or accessed
     */
    public T save(T t) throws Exception {
        String[] columnNames = columnNames();
        int[] columnTypes = columnTypes();
        List<String> keys = new ArrayList<>();
        List<Object> values = new ArrayList<>();
        List<Integer> types = new ArrayList<>();
        for (int i = 0; i < columnNames.length; i++) {
            String columnName = columnNames[i];
            if (columnName.equals(primaryKeyName()) || isTimestampColumn(columnName)) {
                // insert() adds these columns; listing them twice would be rejected by the database.
                continue;
            }
            Field field = getDeclaredField(t.getClass(), changeColumnToFieldName(columnName));
            Object value = field.get(t);
            if (value != null) {
                keys.add(columnName);
                values.add(value);
                types.add(columnTypes[i]);
            }
        }

        long id = insert(keys, values, types.stream().mapToInt(Integer::intValue).toArray());
        Field idField = getDeclaredField(t.getClass(), primaryKeyName());
        idField.set(t, id);
        return t;
    }

    /**
     * Updates the row with the given id from an entity.
     *
     * <p>Only columns whose new value is non-null and differs from the stored value are written.
     * The created/updated timestamp columns are never taken from the entity.
     *
     * @param t  entity carrying the new values
     * @param id primary key of the row to update
     * @throws EmptyResultDataAccessException if no row with that id exists
     * @throws Exception if a mapped field cannot be found or accessed
     */
    public void update(T t, long id) throws Exception {
        T old = find(id);
        if (old == null) {
            throw new EmptyResultDataAccessException(
                    "cannot update " + tableName() + ": no row with " + primaryKeyName() + " = " + id, 1);
        }

        List<String> keys = new ArrayList<>();
        List<Object> values = new ArrayList<>();
        for (String columnName : columnNames()) {
            if (isTimestampColumn(columnName)) {
                // created_at must not change; updated_at is set to now() by the column-level update.
                continue;
            }
            Field field = getDeclaredField(t.getClass(), changeColumnToFieldName(columnName));
            Object value = field.get(t);
            if (value != null && !value.equals(field.get(old))) {
                keys.add(columnName);
                values.add(value);
            }
        }
        update(keys.toArray(new String[0]), values.toArray(new Object[0]), id);
    }

    /**
     * Inserts one row. The primary key is taken from {@code nextval('<table>_id_seq')} and the
     * created/updated timestamp columns are set to {@code now()}.
     *
     * @param columns columns to insert, excluding the primary key and timestamp columns
     * @param args    values, index-aligned with {@code columns}
     * @param types   SQL types, index-aligned with {@code columns}
     * @return the generated primary key, or {@code -1} when the driver reported no numeric key
     */
    public long insert(List<String> columns, List<Object> args, int[] types) {
        StringBuilder insertColumns = new StringBuilder(primaryKeyName());
        StringBuilder insertPlaceholders = new StringBuilder("nextval('" + tableName() + "_id_seq')");
        for (String column : columns) {
            insertColumns.append(',').append(column);
            insertPlaceholders.append(",?");
        }
        insertColumns.append(", ").append(createdAtName()).append(", ").append(updatedAtName());
        insertPlaceholders.append(", now(), now()");

        String sql = "insert into " + tableName()
                + " (" + insertColumns + ") values (" + insertPlaceholders + ")";

        PreparedStatementCreatorFactory pscf = new PreparedStatementCreatorFactory(sql, types);
        pscf.setReturnGeneratedKeys(true);
        logSql("insert.sql: {}, args: {}", sql, args);

        KeyHolder keyHolder = new GeneratedKeyHolder();
        writeJdbcTemplate.update(pscf.newPreparedStatementCreator(args), keyHolder);

        List<Map<String, Object>> keyList = keyHolder.getKeyList();
        if (keyList != null && !keyList.isEmpty()) {
            Object key = keyList.get(0).get(primaryKeyName());
            if (key instanceof Number) {
                // Drivers may report the key as Integer or Long depending on the column type.
                return ((Number) key).longValue();
            }
        }
        return -1;
    }

    /**
     * Updates the given columns of one row and sets the updated timestamp column to {@code now()}.
     * Does nothing when {@code columns} is {@code null} or empty.
     *
     * @param columns columns to assign
     * @param args    new values, index-aligned with {@code columns}
     * @param id      primary key of the row to update
     */
    public void update(String[] columns, Object[] args, long id) {
        if (columns == null || columns.length == 0) {
            return;
        }

        String assignments = Arrays.stream(columns)
                .map(column -> column + " = ?")
                .collect(Collectors.joining(","));
        // Filter on primaryKeyName() so the statement agrees with find() and delete().
        String sql = "update " + tableName() + " set " + updatedAtName() + " = now(), " + assignments
                + " where " + primaryKeyName() + " = " + id;

        logSql("update.sql: {}, args: {}", sql, args);
        writeJdbcTemplate.update(sql, args);
    }

    /**
     * Deletes the row with the given primary key.
     *
     * @param id primary key value
     */
    public void delete(long id) {
        writeJdbcTemplate.update("delete from " + tableName() + " where " + primaryKeyName() + " = ?",
                new Object[]{id});
    }

    /** Row mapper that delegates to the enclosing repository's {@link #convert(ResultSet)}. */
    class BaseRowMapper implements RowMapper<T> {
        @Override
        public T mapRow(ResultSet rs, int rowNum) throws SQLException {
            return convert(rs);
        }
    }

    /**
     * Converts a snake_case column name to a camelCase field name, e.g.
     * {@code user_name -> userName}. The first segment is kept as is; empty segments produced by
     * repeated or trailing underscores are ignored.
     *
     * @param columnName column name to convert
     * @return the corresponding field name
     */
    public static String changeColumnToFieldName(String columnName) {
        String[] segments = columnName.split("_");
        if (segments.length == 0) {
            // Input consisted solely of underscores.
            return "";
        }
        StringBuilder fieldName = new StringBuilder(segments[0]);
        for (int i = 1; i < segments.length; i++) {
            String segment = segments[i];
            if (segment.isEmpty()) {
                continue;
            }
            // Locale.ROOT keeps the result independent of the JVM's default locale.
            fieldName.append(segment.substring(0, 1).toUpperCase(Locale.ROOT))
                    .append(segment.substring(1));
        }
        return fieldName.toString();
    }

    /** Per-class cache of accessible fields, keyed by field name. Concurrent maps: it is static and shared. */
    static final Map<Class<?>, Map<String, Field>> fieldMapCache = new ConcurrentHashMap<>();

    /**
     * Returns an accessible field of {@code target}, looking in the class itself first and then
     * in its superclasses. Results are cached per class.
     *
     * @param target    class to inspect
     * @param fieldName field name to look up
     * @return the field, already made accessible
     * @throws NoSuchFieldException if neither the class nor any superclass declares the field
     */
    static Field getDeclaredField(Class<?> target, String fieldName) throws NoSuchFieldException {
        Map<String, Field> fields = fieldMapCache.computeIfAbsent(target, key -> new ConcurrentHashMap<>());
        Field cached = fields.get(fieldName);
        if (cached != null) {
            return cached;
        }
        // Two threads may resolve the same field concurrently; both obtain an equivalent Field.
        Field resolved = lookupField(target, fieldName);
        resolved.setAccessible(true);
        fields.put(fieldName, resolved);
        return resolved;
    }

    /** Finds a declared field on {@code target} or the nearest superclass that declares it. */
    private static Field lookupField(Class<?> target, String fieldName) throws NoSuchFieldException {
        for (Class<?> type = target; type != null; type = type.getSuperclass()) {
            try {
                return type.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // Not declared on this level; continue with the superclass.
            }
        }
        throw new NoSuchFieldException(fieldName);
    }

    /** Builds the select statement for a condition, omitting the where clause when it is blank. */
    private String buildSelect(String condition) {
        return isBlank(condition) ? selectSql() : selectWhereSql() + condition;
    }

    /** Tells whether a column is one of the timestamp columns maintained by this class. */
    private boolean isTimestampColumn(String columnName) {
        return columnName.equals(createdAtName()) || columnName.equals(updatedAtName());
    }

    /** Logs a statement with its arguments at INFO level, rendering arguments only when enabled. */
    private void logSql(String format, String sql, Object args) {
        if (logger.isInfoEnabled()) {
            logger.info(format, sql, toJson(args));
        }
    }

    /** Renders a value as JSON for logging, falling back to a plain string if serialization fails. */
    private String toJson(Object value) {
        try {
            return jsonMapper.writeValueAsString(value);
        } catch (JsonProcessingException e) {
            logger.warn("could not serialize sql args for logging", e);
            if (value instanceof Object[]) {
                return Arrays.deepToString((Object[]) value);
            }
            return String.valueOf(value);
        }
    }

    private static boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }

}