package com.mogujie.trade.tsharding.route.orm;

import com.mogujie.trade.tsharding.client.ShardingCaculator;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtMethod;
import javassist.bytecode.ClassFile;
import javassist.bytecode.ConstPool;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.factory.ObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.reflection.wrapper.ObjectWrapperFactory;
import org.apache.ibatis.session.Configuration;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Base class for generic Mapper enhancement. Extend this class to contribute
 * additional (sharded) Mapper SQL.
 *
 * <p>It offers two facilities that are visible in this file:
 * <ul>
 *   <li>{@link #enhanceMapperClass(String)} generates, for every method declared on a mapper
 *       type, a new interface containing {@value #SHARDING_COUNT} suffixed copies of that
 *       method and loads it into the class loader;</li>
 *   <li>{@link #setSqlSource(MappedStatement, Configuration)} registers the matching
 *       {@value #SHARDING_COUNT} suffixed {@link MappedStatement}s in a MyBatis
 *       {@link Configuration}.</li>
 * </ul>
 *
 * @author qigong on 5/1/15
 */
public abstract class MapperEnhancer {

    /** Number of suffixed variants generated per mapper method / mapped statement. */
    private static final int SHARDING_COUNT = 512;

    private static final ClassPool pool = ClassPool.getDefault();

    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();

    /** Enhancer methods registered through {@link #addMethodMap(String, Method)}, keyed by method name. */
    private final Map<String, Method> methodMap = new HashMap<String, Method>();

    /** Mapper type this enhancer applies to; {@code null} when the no-arg constructor was used. */
    private final Class<?> mapperClass;

    /**
     * Creates an enhancer bound to the given mapper type.
     *
     * @param mapperClass mapper type that statements must be assignable to in order to be supported
     */
    public MapperEnhancer(Class<?> mapperClass) {
        this.mapperClass = mapperClass;
    }

    /**
     * Creates an enhancer that is not bound to any mapper type;
     * {@link #supportMethod(String)} then always returns {@code false}.
     */
    public MapperEnhancer() {
        this(null);
    }

    /**
     * Marker method for code enhancement.
     *
     * @param record ignored
     * @return the constant string {@code "enhancedShardingSQL"}
     */
    public String enhancedShardingSQL(Object record) {
        return "enhancedShardingSQL";
    }

    /**
     * Enhances a mapper: for every method declared on {@code mapperClassName} a new interface
     * named {@code mapperClassName + "Sharding" + methodName} is generated, populated with
     * {@value #SHARDING_COUNT} copies of the method (each with a numeric suffix appended to its
     * name) and loaded into the class loader.
     *
     * @param mapperClassName fully-qualified name of the mapper type to enhance
     * @throws Exception if the type cannot be loaded or the byte-code generation fails
     */
    public static void enhanceMapperClass(String mapperClassName) throws Exception {

        Class<?> originClass = Class.forName(mapperClassName);
        Method[] originMethods = originClass.getDeclaredMethods();

        CtClass cc = pool.get(mapperClassName);

        for (CtMethod ctMethod : cc.getDeclaredMethods()) {
            String methodName = ctMethod.getName();
            CtClass enhanceClass = pool.makeInterface(mapperClassName + "Sharding" + methodName);

            // The reflective counterpart is the same for all generated copies, so resolve it once
            // instead of once per copy.
            Method originMethod = getOriginMethod(methodName, originMethods);
            boolean copyParameterAnnotations = isFirstParameterAnnotated(originMethod);

            for (long i = 0L; i < SHARDING_COUNT; i++) {
                CtMethod newMethod = new CtMethod(ctMethod.getReturnType(), methodName + ShardingCaculator.getNumberWithZeroSuffix(i), ctMethod.getParameterTypes(), enhanceClass);

                if (copyParameterAnnotations) {
                    ClassFile ccFile = enhanceClass.getClassFile();
                    ConstPool constPool = ccFile.getConstPool();

                    // Copy the parameter annotations (and their values) so that MyBatis can
                    // still bind the parameters of the generated mapper method dynamically.
                    newMethod.getMethodInfo().addAttribute(MapperAnnotationEnhancer.duplicateParameterAnnotationsAttribute(constPool, originMethod));
                }
                enhanceClass.addMethod(newMethod);
            }

            // Load the generated interface. Since 2015-09-22 the class file is no longer
            // written to the local disk.
            enhanceClass.toClass();
        }
    }

    /**
     * Tells whether the first parameter of {@code method} carries at least one annotation.
     * A method without parameters yields {@code false} instead of failing on index 0.
     */
    private static boolean isFirstParameterAnnotated(Method method) {
        java.lang.annotation.Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        return parameterAnnotations.length > 0 && parameterAnnotations[0].length > 0;
    }

    /**
     * Finds the first reflective method whose name equals {@code methodName}.
     *
     * @throws RuntimeException if no method with that name exists
     */
    private static Method getOriginMethod(String methodName, Method[] originMethods) {
        for (Method method : originMethods) {
            if (methodName.equals(method.getName())) {
                return method;
            }
        }
        throw new RuntimeException("enhanceMapperClass find method error!");
    }

    /**
     * Registers an enhancer method under the given name.
     *
     * @param methodName name used to look the method up later
     * @param method     method to invoke for that name
     */
    public void addMethodMap(String methodName, Method method) {
        methodMap.put(methodName, method);
    }

    /**
     * Wraps an object in a {@link MetaObject} using the default object factory and the default
     * object wrapper factory. According to the original author's note this exists to support
     * older MyBatis versions.
     *
     * @param object object to wrap
     * @return the {@link MetaObject} for {@code object}
     */
    public static MetaObject forObject(Object object) {
        return MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY);
    }

    /**
     * Tells whether this enhancer supports the given mapped statement: the mapper type encoded
     * in {@code msId} must be assignable to this enhancer's mapper class and a method must have
     * been registered under the statement's method name.
     *
     * @param msId mapped statement id ({@code <mapper class name>.<method name>})
     * @return {@code true} if supported; {@code false} otherwise, including when this enhancer
     *         has no mapper class
     * @throws RuntimeException if the mapper type named by {@code msId} cannot be loaded
     */
    public boolean supportMethod(String msId) {
        if (this.mapperClass == null) {
            // Created through the no-arg constructor: nothing can be supported.
            return false;
        }
        Class<?> statementMapperClass = getMapperClass(msId);
        if (this.mapperClass.isAssignableFrom(statementMapperClass)) {
            String methodName = getMethodName(msId);
            return methodMap.get(methodName) != null;
        }
        return false;
    }

    /**
     * Replaces the {@code sqlSource} property of a mapped statement.
     *
     * @param ms        mapped statement to modify
     * @param sqlSource new SQL source
     */
    protected void setSqlSource(MappedStatement ms, SqlSource sqlSource) {
        MetaObject msObject = forObject(ms);
        msObject.setValue("sqlSource", sqlSource);
    }

    /**
     * Invokes the enhancer method registered for {@code ms}.
     *
     * <p>If that method returns {@code void} it is invoked once with {@code ms}. If it returns a
     * {@link SqlSource} it is invoked {@value #SHARDING_COUNT} times with
     * {@code (ms, configuration, index)}; each result is registered in {@code configuration} as
     * a copy of {@code ms} under a suffixed statement id.
     *
     * @param ms            the original mapped statement
     * @param configuration configuration that receives the generated statements
     * @throws RuntimeException if no method is registered for the statement, if the registered
     *                          method has an unsupported return type, or if invoking it fails
     * @throws Exception        declared for backward compatibility
     */
    public void setSqlSource(MappedStatement ms, Configuration configuration) throws Exception {
        String methodName = getMethodName(ms);
        Method method = methodMap.get(methodName);
        if (method == null) {
            // Fail with context instead of a bare NullPointerException.
            throw new RuntimeException("No enhancer method registered for mapped statement: " + ms.getId());
        }
        try {
            if (method.getReturnType() == Void.TYPE) {
                method.invoke(this, ms);
            } else if (SqlSource.class.isAssignableFrom(method.getReturnType())) {
                // Code enhancement: expand the statement into SHARDING_COUNT statements.
                // The index stays a long so that it is passed to the method boxed as a Long.
                for (long i = 0; i < SHARDING_COUNT; i++) {

                    // The new SQL carrying the sharding index.
                    SqlSource sqlSource = (SqlSource) method.invoke(this, ms, configuration, i);

                    // NOTE(review): this rewrite only lines up with the interface name produced by
                    // enhanceMapperClass when the mapper class name ends with "Mapper" and
                    // "Mapper." occurs nowhere else in the id - confirm against callers.
                    String newMsId = ms.getId() + ShardingCaculator.getNumberWithZeroSuffix(i);
                    newMsId = newMsId.replace("Mapper.", "MapperSharding" + methodName + ".");

                    // Register the new statement in the configuration.
                    MappedStatement newMs = copyFromMappedStatement(ms, sqlSource, newMsId);
                    configuration.addMappedStatement(newMs);
                    setSqlSource(newMs, sqlSource);
                }
            } else {
                throw new RuntimeException("自定义Mapper方法返回类型错误,可选的返回类型为void和SqlSource!");
            }
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        } catch (InvocationTargetException e) {
            throw new RuntimeException(e.getTargetException() != null ? e.getTargetException() : e);
        }
    }

    /**
     * Builds a new mapped statement with the given id and SQL source, copying the resource,
     * fetch size, statement type, key generator, timeout, parameter map, result maps, result set
     * type and cache settings from {@code ms}.
     *
     * <p>NOTE(review): key properties / key columns are not copied here; verify whether
     * generated keys are needed on the copied statements.
     *
     * @param ms           statement to copy from
     * @param newSqlSource SQL source of the new statement
     * @param newMsId      id of the new statement
     * @return the newly built statement
     */
    protected MappedStatement copyFromMappedStatement(MappedStatement ms,
                                                      SqlSource newSqlSource, String newMsId) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), newMsId, newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // statement timeout
        builder.timeout(ms.getTimeout());
        // parameter map
        builder.parameterMap(ms.getParameterMap());
        // result maps
        List<ResultMap> resultMaps = ms.getResultMaps();
        builder.resultMaps(resultMaps);
        builder.resultSetType(ms.getResultSetType());
        // cache settings
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /**
     * Loads the mapper type encoded in a mapped statement id, i.e. everything before the last
     * {@code '.'}.
     *
     * @param msId mapped statement id
     * @return the mapper type
     * @throws RuntimeException if {@code msId} contains no {@code '.'} or the type cannot be
     *                          loaded (the {@link ClassNotFoundException} is kept as the cause)
     */
    public static Class<?> getMapperClass(String msId) {
        int lastDot = msId.lastIndexOf(".");
        if (lastDot < 0) {
            // No class part at all: report it like any other unresolvable mapper.
            throw new RuntimeException("无法获取Mapper接口信息:" + msId);
        }
        String mapperClassStr = msId.substring(0, lastDot);
        try {
            return Class.forName(mapperClassStr);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("无法获取Mapper接口信息:" + msId, e);
        }
    }

    /**
     * Returns the method name of a mapped statement.
     *
     * @param ms mapped statement
     * @return the part of the statement id after the last {@code '.'}
     */
    public static String getMethodName(MappedStatement ms) {
        return getMethodName(ms.getId());
    }

    /**
     * Returns the method name encoded in a mapped statement id.
     *
     * @param msId mapped statement id
     * @return the part of {@code msId} after the last {@code '.'}, or the whole id if it
     *         contains none
     */
    public static String getMethodName(String msId) {
        return msId.substring(msId.lastIndexOf(".") + 1);
    }
}