package org.apache.ibatis.plugin;

import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.annotations.CryptField;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.session.defaults.DefaultSqlSession;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * MyBatis interceptor that encrypts sensitive values before they reach the database and decrypts
 * them in query results (database data masking).
 *
 * <p>Project: mybatis-crypt. Recommended cipher: AES-192 + Base64.
 *
 * <p>What gets encrypted/decrypted is driven by {@link CryptField}:
 * <ul>
 *   <li>On a mapper method parameter with {@code encrypt() == true}: the parameter whose name is
 *       {@code value()} is encrypted if it is a {@code String}, or its {@code String} elements if it
 *       is a {@code List}. A lone {@code String} (or lone {@code List}) argument is encrypted if the
 *       method declares any such parameter annotation.</li>
 *   <li>On a mapper method with {@code decrypt() == true}: a {@code String} result, or the
 *       {@code String} elements of a {@code List} result, are decrypted.</li>
 *   <li>On a bean field (including inherited, non-static fields): {@code String} values and
 *       {@code List} elements are encrypted when the bean is a parameter ({@code encrypt() == true})
 *       and decrypted when it is a result element ({@code decrypt() == true}).</li>
 * </ul>
 *
 * <p>Parameter beans and lists are modified in place, so callers observe the encrypted values after
 * the statement has executed.
 *
 * <p>The cipher itself is not implemented: {@link #encrypt(String)} and {@link #decrypt(String)} are
 * placeholders that currently return an empty string and must be filled in or overridden.
 *
 * @author miaoxw
 * @since 2017-11-22
 */
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
})
public class CryptInterceptor implements Interceptor {

    /** Prefix of the generic parameter aliases ({@code param1}, {@code param2}, ...) in a ParamMap. */
    private static final String PARAM = "param";

    /** Key under which a lone {@code List} argument appears in a StrictMap. */
    private static final String PARAM_TYPE_LIST = "list";

    /** Key under which a lone collection argument appears in a StrictMap. */
    private static final String PARAM_TYPE_COLLECTION = "collection";

    /** Separator between the mapper namespace and the method name in a statement id. */
    private static final String MAPPED_STATEMENT_ID_SEPARATOR = ".";

    /**
     * Statement id -> names of the mapper-method parameters annotated with an encrypting
     * {@code @CryptField}. Used for encryption decisions. Values are unmodifiable.
     */
    private static final ConcurrentHashMap<String, Set<String>> METHOD_PARAM_ANNOTATIONS_MAP = new ConcurrentHashMap<>();

    /**
     * Statement id -> whether the mapper method carries a decrypting {@code @CryptField}.
     * Used for decryption decisions.
     */
    private static final ConcurrentHashMap<String, Boolean> METHOD_ANNOTATIONS_MAP = new ConcurrentHashMap<>();

    public CryptInterceptor() {

    }

    /**
     * Encrypts the statement parameter before execution and decrypts the result afterwards.
     *
     * @param invocation the intercepted {@code Executor.update}/{@code Executor.query} call;
     *                   {@code args[0]} is the {@link MappedStatement}, {@code args[1]} the parameter
     * @return the result of the invocation, with annotated values decrypted
     * @throws Throwable anything thrown by the invocation or by the cipher
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];
        // A bare String parameter is replaced; every other parameter is modified in place.
        args[1] = encryptParameter(statement, args[1]);

        Object returnValue = invocation.proceed();

        return decryptResult(statement, returnValue);
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }

    /**
     * Encrypts the annotated parts of a statement parameter.
     *
     * @param statement the statement being executed
     * @param parameter the statement parameter (may be {@code null})
     * @return the parameter to pass on: an encrypted copy for a bare {@code String}, otherwise the
     *         same (possibly mutated) object
     * @throws Exception if encryption fails
     */
    @SuppressWarnings("unchecked")
    private Object encryptParameter(MappedStatement statement, Object parameter) throws Exception {
        if (isNotCrypt(parameter)) {
            return parameter;
        }
        if (parameter instanceof String) {
            // Single String parameter.
            return stringEncrypt((String) parameter, getParameterAnnotations(statement));
        }
        if (parameter instanceof DefaultSqlSession.StrictMap) {
            // Single List/collection parameter wrapped by MyBatis.
            encryptStrictMap(statement, (Map<String, Object>) parameter);
        } else if (parameter instanceof MapperMethod.ParamMap) {
            // Multiple (or @Param-named) parameters.
            encryptParamMap(statement, (Map<String, Object>) parameter);
        } else {
            // Single bean parameter.
            beanEncrypt(parameter);
        }
        return parameter;
    }

    /**
     * Encrypts the list held by a StrictMap-wrapped single collection argument.
     *
     * @param statement the statement being executed
     * @param strictMap the wrapper map
     * @throws Exception if encryption fails
     */
    @SuppressWarnings("unchecked")
    private void encryptStrictMap(MappedStatement statement, Map<String, Object> strictMap) throws Exception {
        for (Map.Entry<String, Object> entry : strictMap.entrySet()) {
            String key = entry.getKey();
            // Skipped so that a List registered under both keys is only encrypted once.
            if (key.contains(PARAM_TYPE_COLLECTION)) {
                continue;
            }
            if (key.contains(PARAM_TYPE_LIST) && entry.getValue() instanceof List) {
                boolean encryptStrings = !getParameterAnnotations(statement).isEmpty();
                listEncrypt((List<Object>) entry.getValue(), encryptStrings);
            }
        }
    }

    /**
     * Encrypts the entries of a multi-parameter ParamMap.
     *
     * @param statement the statement being executed
     * @param paramMap  parameter name -> argument
     * @throws Exception if encryption fails
     */
    @SuppressWarnings("unchecked")
    private void encryptParamMap(MappedStatement statement, Map<String, Object> paramMap) throws Exception {
        Set<String> cryptNames = getParameterAnnotations(statement);
        for (Map.Entry<String, Object> entry : paramMap.entrySet()) {
            String name = entry.getKey();
            Object value = entry.getValue();
            // Skip values that never need encryption, nested maps, and the generic param<N>
            // aliases (they refer to the same arguments as the named entries).
            if (isNotCrypt(value) || value instanceof Map || isGenericParamName(name)) {
                continue;
            }
            if (value instanceof String) {
                entry.setValue(stringEncrypt(name, (String) value, cryptNames));
            } else if (value instanceof List) {
                listEncrypt((List<Object>) value, cryptNames.contains(name));
            } else {
                beanEncrypt(value);
            }
        }
    }

    /**
     * Decrypts the annotated parts of a statement result.
     *
     * @param statement   the statement that was executed
     * @param returnValue the value returned by the invocation (may be {@code null})
     * @return the decrypted {@code String}, or the same (possibly mutated) object
     * @throws Exception if decryption fails
     */
    @SuppressWarnings("unchecked")
    private Object decryptResult(MappedStatement statement, Object returnValue) throws Exception {
        if (isNotCrypt(returnValue)) {
            return returnValue;
        }
        boolean decryptStrings = getMethodAnnotations(statement);
        if (returnValue instanceof String) {
            return decryptStrings ? stringDecrypt((String) returnValue) : returnValue;
        }
        if (returnValue instanceof List) {
            listDecrypt((List<Object>) returnValue, decryptStrings);
        }
        return returnValue;
    }

    /**
     * Tells whether {@code name} is one of the generic aliases {@code param1}, {@code param2}, ...
     *
     * @param name a ParamMap key
     * @return {@code true} for {@code "param"} followed by one or more ASCII digits
     */
    private static boolean isGenericParamName(String name) {
        if (name == null || !name.startsWith(PARAM) || name.length() == PARAM.length()) {
            return false;
        }
        for (int i = PARAM.length(); i < name.length(); i++) {
            char c = name.charAt(i);
            if (c < '0' || c > '9') {
                return false;
            }
        }
        return true;
    }

    /**
     * Tells whether the mapper method behind {@code statement} asks for decryption of its result.
     * The answer (including "no mapper method found") is cached per statement id.
     *
     * @param statement the mapped statement
     * @return {@code Boolean.TRUE} if the method carries {@code @CryptField} with
     *         {@code decrypt() == true}, otherwise {@code Boolean.FALSE}; never {@code null}
     */
    private Boolean getMethodAnnotations(MappedStatement statement) {
        final String id = statement.getId();
        Boolean cached = METHOD_ANNOTATIONS_MAP.get(id);
        if (cached != null) {
            return cached;
        }
        Method method = getMethodByMappedStatementId(id);
        CryptField cryptField = method == null ? null : method.getAnnotation(CryptField.class);
        Boolean decrypt = cryptField != null && cryptField.decrypt();
        Boolean previous = METHOD_ANNOTATIONS_MAP.putIfAbsent(id, decrypt);
        return previous != null ? previous : decrypt;
    }

    /**
     * Collects the names ({@code value()}) of the mapper-method parameters annotated with an
     * encrypting {@code @CryptField}. The result (including the empty result when no mapper method
     * is found) is cached per statement id.
     *
     * @param statement the mapped statement
     * @return an unmodifiable, possibly empty set of parameter names; never {@code null}
     */
    private Set<String> getParameterAnnotations(MappedStatement statement) {
        final String id = statement.getId();
        Set<String> cached = METHOD_PARAM_ANNOTATIONS_MAP.get(id);
        if (cached != null) {
            return cached;
        }
        Set<String> names = new HashSet<>();
        Method method = getMethodByMappedStatementId(id);
        if (method != null) {
            for (Annotation[] annotations : method.getParameterAnnotations()) {
                for (Annotation annotation : annotations) {
                    if (annotation instanceof CryptField) {
                        CryptField cryptField = (CryptField) annotation;
                        if (cryptField.encrypt()) {
                            names.add(cryptField.value());
                        }
                        break;
                    }
                }
            }
        }
        Set<String> result = Collections.unmodifiableSet(names);
        Set<String> previous = METHOD_PARAM_ANNOTATIONS_MAP.putIfAbsent(id, result);
        return previous != null ? previous : result;
    }

    /**
     * Resolves the mapper method for a statement id of the form {@code <namespace>.<methodName>}.
     * If several public methods share the name, the first one returned by
     * {@link Class#getMethods()} is used.
     *
     * @param id the mapped statement id
     * @return the method, or {@code null} if the id has no namespace, the namespace is not a
     *         loadable class (e.g. an XML-only mapper), or no method has that name
     */
    private Method getMethodByMappedStatementId(String id) {
        int separator = id.lastIndexOf(MAPPED_STATEMENT_ID_SEPARATOR);
        if (separator <= 0) {
            return null;
        }
        String className = id.substring(0, separator);
        String methodName = id.substring(separator + 1);
        Class<?> mapperClass;
        try {
            mapperClass = Class.forName(className);
        } catch (ClassNotFoundException e) {
            // A namespace that is not a Java type cannot carry method annotations; treat as "none"
            // instead of failing every statement in that namespace.
            return null;
        }
        for (Method method : mapperClass.getMethods()) {
            if (method.getName().equals(methodName)) {
                return method;
            }
        }
        return null;
    }

    /**
     * Tells whether a value can be skipped for encryption/decryption.
     *
     * @param o the value to check
     * @return {@code true} for {@code null}, {@code Double}, {@code Integer}, {@code Long} and
     *         {@code Boolean}
     */
    private boolean isNotCrypt(Object o) {
        return o == null || o instanceof Double || o instanceof Integer || o instanceof Long || o instanceof Boolean;
    }

    /**
     * Unconditionally encrypts a non-blank string.
     *
     * @param str the plaintext
     * @return the ciphertext, or {@code str} itself if it is blank
     * @throws Exception if encryption fails
     */
    private String stringEncrypt(String str) throws Exception {
        return stringEncrypt(null, str, null, null);
    }

    /**
     * Encrypts a lone {@code String} parameter if the method declares any encrypting parameter
     * annotation.
     *
     * @param str the plaintext
     * @param set names of encrypting parameter annotations
     * @return the ciphertext, or {@code str} unchanged
     * @throws Exception if encryption fails
     */
    private String stringEncrypt(String str, Set<String> set) throws Exception {
        return stringEncrypt(null, str, set, true);
    }

    /**
     * Encrypts a named {@code String} parameter if its name is in {@code set}.
     *
     * @param name the parameter name
     * @param str  the plaintext
     * @param set  names of encrypting parameter annotations
     * @return the ciphertext, or {@code str} unchanged
     * @throws Exception if encryption fails
     */
    private String stringEncrypt(String name, String str, Set<String> set) throws Exception {
        return stringEncrypt(name, str, set, false);
    }

    /**
     * Decides whether {@code str} must be encrypted and encrypts it if so. Blank strings are never
     * encrypted.
     *
     * @param name     the parameter name (only used when {@code isSingle} is {@code false})
     * @param str      the plaintext
     * @param set      names of encrypting parameter annotations (may be {@code null})
     * @param isSingle {@code null}: always encrypt; {@code true}: encrypt if {@code set} is
     *                 non-empty; {@code false}: encrypt if {@code set} contains {@code name}
     * @return the ciphertext, or {@code str} unchanged
     * @throws Exception if encryption fails
     */
    private String stringEncrypt(String name, String str, Set<String> set, Boolean isSingle) throws Exception {
        if (StringUtils.isBlank(str)) {
            return str;
        }
        boolean shouldEncrypt;
        if (isSingle == null) {
            shouldEncrypt = true;
        } else if (set == null || set.isEmpty()) {
            shouldEncrypt = false;
        } else {
            shouldEncrypt = isSingle || set.contains(name);
        }
        return shouldEncrypt ? encrypt(str) : str;
    }

    /**
     * Decrypts a string. Blank strings, and strings that do not split into at least two
     * {@code '|'}-separated segments, are returned unchanged.
     * NOTE(review): the '|' check presumably recognizes the ciphertext format; confirm against the
     * real cipher.
     *
     * @param str the value read from the database
     * @return the decrypted value, or {@code str} unchanged
     * @throws Exception if decryption fails
     */
    private String stringDecrypt(String str) throws Exception {
        if (StringUtils.isBlank(str)) {
            return str;
        }
        if (str.split("\\|").length < 2) {
            return str;
        }
        return decrypt(str);
    }

    /**
     * Encrypts a non-blank plaintext value. Override (or fill in) with the real cipher.
     *
     * @param plainText the value to encrypt, never blank
     * @return the ciphertext; the placeholder implementation returns an empty string
     * @throws Exception if encryption fails
     */
    protected String encrypt(String plainText) throws Exception {
        // TODO: implement encryption (AES-192 + Base64 recommended).
        return "";
    }

    /**
     * Decrypts a non-blank ciphertext value. Override (or fill in) with the real cipher.
     *
     * @param cipherText the value to decrypt, never blank
     * @return the plaintext; the placeholder implementation returns an empty string
     * @throws Exception if decryption fails
     */
    protected String decrypt(String cipherText) throws Exception {
        // TODO: implement decryption matching encrypt(String).
        return "";
    }

    /**
     * Encrypts a list in place: {@code String} elements (only if {@code bo}) and annotated fields of
     * bean elements. Nulls, boxed primitives and maps are skipped individually, so they do not stop
     * the remaining elements from being processed.
     *
     * @param list the list to modify in place (must support {@code set} when it holds strings)
     * @param bo   whether {@code String} elements are encrypted
     * @return {@code list}
     * @throws Exception if encryption fails
     */
    private List<Object> listEncrypt(List<Object> list, boolean bo) throws Exception {
        for (int i = 0; i < list.size(); i++) {
            Object element = list.get(i);
            if (isNotCrypt(element) || element instanceof Map) {
                continue;
            }
            if (element instanceof String) {
                if (bo) {
                    list.set(i, stringEncrypt((String) element));
                }
                continue;
            }
            beanEncrypt(element);
        }
        return list;
    }

    /**
     * Decrypts a list in place: {@code String} elements (only if {@code bo}) and annotated fields of
     * bean elements. Nulls (e.g. empty result rows), boxed primitives and maps are skipped
     * individually, so they do not stop the remaining elements from being processed.
     *
     * @param list the list to modify in place (must support {@code set} when it holds strings)
     * @param bo   whether {@code String} elements are decrypted
     * @return {@code list}
     * @throws Exception if decryption fails
     */
    private List<Object> listDecrypt(List<Object> list, boolean bo) throws Exception {
        for (int i = 0; i < list.size(); i++) {
            Object element = list.get(i);
            if (isNotCrypt(element) || element instanceof Map) {
                continue;
            }
            if (element instanceof String) {
                if (bo) {
                    list.set(i, stringDecrypt((String) element));
                }
                continue;
            }
            beanDecrypt(element);
        }
        return list;
    }

    /**
     * Collects the non-static fields annotated with {@code @CryptField} declared by {@code clazz}
     * and its superclasses (excluding {@code Object}).
     *
     * @param clazz the bean class
     * @return the annotated fields, subclass fields first
     */
    private static List<Field> getCryptFields(Class<?> clazz) {
        List<Field> fields = new ArrayList<>();
        for (Class<?> c = clazz; c != null && c != Object.class; c = c.getSuperclass()) {
            for (Field field : c.getDeclaredFields()) {
                if (!Modifier.isStatic(field.getModifiers()) && field.isAnnotationPresent(CryptField.class)) {
                    fields.add(field);
                }
            }
        }
        return fields;
    }

    /**
     * Encrypts, in place, the fields of {@code val} annotated with {@code @CryptField} whose
     * {@code encrypt()} is {@code true}: {@code String} values are replaced, {@code List} values are
     * processed with {@link #listEncrypt(List, boolean)} (strings included).
     *
     * @param val the bean to modify, not {@code null}
     * @throws Exception if reflection or encryption fails
     */
    @SuppressWarnings("unchecked")
    private void beanEncrypt(Object val) throws Exception {
        for (Field field : getCryptFields(val.getClass())) {
            if (!field.getAnnotation(CryptField.class).encrypt()) {
                continue;
            }
            field.setAccessible(true);
            Object fieldValue = field.get(val);
            if (fieldValue instanceof String) {
                // Assignment-safe: the field already holds a String, so its type accepts one.
                field.set(val, stringEncrypt((String) fieldValue));
            } else if (fieldValue instanceof List) {
                listEncrypt((List<Object>) fieldValue, true);
            }
        }
    }

    /**
     * Decrypts, in place, the fields of {@code val} annotated with {@code @CryptField} whose
     * {@code decrypt()} is {@code true}: {@code String} values are replaced, {@code List} values are
     * processed with {@link #listDecrypt(List, boolean)} (strings included).
     *
     * @param val the bean to modify, not {@code null}
     * @throws Exception if reflection or decryption fails
     */
    @SuppressWarnings("unchecked")
    private void beanDecrypt(Object val) throws Exception {
        for (Field field : getCryptFields(val.getClass())) {
            if (!field.getAnnotation(CryptField.class).decrypt()) {
                continue;
            }
            field.setAccessible(true);
            Object fieldValue = field.get(val);
            if (fieldValue instanceof String) {
                // Assignment-safe: the field already holds a String, so its type accepts one.
                field.set(val, stringDecrypt((String) fieldValue));
            } else if (fieldValue instanceof List) {
                listDecrypt((List<Object>) fieldValue, true);
            }
        }
    }
}