package krpc.rpc.core;

import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.Message;
import com.google.protobuf.Message.Builder;
import krpc.rpc.core.proto.RpcMeta;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.InputStream;
import java.lang.reflect.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Reflection helpers for the RPC core.
 *
 * <p>Responsibilities:
 * <ul>
 *   <li>reading/writing the retCode/retMsg of response messages (generated classes through
 *       their private Java fields cached at registration time, {@link DynamicMessage}s through
 *       their descriptors);</li>
 *   <li>patching fields of already-built {@link RpcMeta} / {@link RpcMeta.Trace} objects;</li>
 *   <li>building (and caching) error response objects;</li>
 *   <li>introspecting service interfaces for service id, message ids and method signatures.</li>
 * </ul>
 */
public class ReflectionUtils {

    static Logger log = LoggerFactory.getLogger(ReflectionUtils.class);

    /** Returned by {@link #getRetCode(Object)} when no retCode can be read. */
    private static final int UNKNOWN_RET_CODE = -99999999;

    /** Name suffix of the static int constants that declare message ids on a service interface. */
    private static final String MSG_ID_SUFFIX = "MsgId";

    static Class<?>[] dummyTypes = new Class<?>[0];
    static Object[] dummyParameters = new Object[0];
    static Class<?>[] callableTypes = new Class<?>[]{RpcCallable.class};

    public static String retCodeFieldInMap = "retCode"; // can be configured
    public static String retMsgFieldInMap = "retMsg"; // can be configured
    // Private Java field names inside generated classes; recomputed from the *InMap names by init().
    static String retCodeField = "retCode_";
    static String retMsgField = "retMsg_";
    // Response class name -> accessible retCode/retMsg field. Written by the registration
    // methods (getMethodInfo/getParsers/getAsyncMethodInfo) and read by the get/set helpers,
    // possibly from different threads, hence concurrent maps.
    static Map<String, Field> retCodeFields = new ConcurrentHashMap<String, Field>();
    static Map<String, Field> retMsgFields = new ConcurrentHashMap<String, Field>();
    static Field metaSequenceField = null;
    static Field metaCompressField = null;
    static Field metaEncryptField = null;
    // Cache of generated error responses keyed by "<class name or cacheName>:<retCode>:<retMsg>".
    // NOTE(review): unbounded - every distinct retMsg adds an entry; confirm retMsg values are bounded.
    static ConcurrentHashMap<String, Object> errors = new ConcurrentHashMap<String, Object>();

    static Field metaTraceField = null;
    static Field metaPeersField = null;
    static Field metaTraceIdField = null;
    static Field metaParentSpanIdField = null;
    static Field metaSpanIdField = null;

    static {
        init();
    }

    /**
     * Resolves the private fields of {@link RpcMeta} and {@link RpcMeta.Trace} patched by the
     * {@code adjust*}/{@code update*} methods, makes them accessible, and derives the generated
     * Java field names for retCode/retMsg from {@link #retCodeFieldInMap}/{@link #retMsgFieldInMap}.
     *
     * <p>Runs from the static initializer; may be called again after changing the configurable
     * names. Fields already cached for registered response classes are not touched.
     *
     * @throws RuntimeException if a field is missing or cannot be made accessible (cause attached)
     */
    public static void init() {
        try {
            metaTraceField = RpcMeta.class.getDeclaredField("trace_");
            metaTraceField.setAccessible(true);
            metaPeersField = RpcMeta.Trace.class.getDeclaredField("peers_");
            metaPeersField.setAccessible(true);
            metaTraceIdField = RpcMeta.Trace.class.getDeclaredField("traceId_");
            metaTraceIdField.setAccessible(true);
            metaParentSpanIdField = RpcMeta.Trace.class.getDeclaredField("parentSpanId_");
            metaParentSpanIdField.setAccessible(true);
            metaSpanIdField = RpcMeta.Trace.class.getDeclaredField("spanId_");
            metaSpanIdField.setAccessible(true);

            metaCompressField = RpcMeta.class.getDeclaredField("compress_");
            metaCompressField.setAccessible(true);
            metaEncryptField = RpcMeta.class.getDeclaredField("encrypt_");
            metaEncryptField.setAccessible(true);
            metaSequenceField = RpcMeta.class.getDeclaredField("sequence_");
            metaSequenceField.setAccessible(true);
            retCodeField = retCodeFieldInMap + "_";
            retMsgField = retMsgFieldInMap + "_";
        } catch (Exception e) {
            throw new RuntimeException("ReflectionUtils init failed", e);
        }
    }

    /**
     * Reads the retCode of a response message.
     *
     * <p>A {@link DynamicMessage} is read through its descriptor field named
     * {@link #retCodeFieldInMap}. Any other object is read through the field cached for its class
     * by {@link #getMethodInfo}, {@link #getParsers} or {@link #getAsyncMethodInfo}.
     *
     * @param object response message
     * @return the retCode, or {@code -99999999} if no retCode field is known for the object or
     *         reading it fails (failures of the reflective read are logged)
     */
    public static int getRetCode(Object object) {
        if (object instanceof DynamicMessage) {
            DynamicMessage dm = (DynamicMessage) object;
            FieldDescriptor f = dm.getDescriptorForType().findFieldByName(retCodeFieldInMap);
            if (f == null)
                return UNKNOWN_RET_CODE;  // should not be executed
            return (Integer) dm.getField(f);
        }

        try {
            Field f = retCodeFields.get(object.getClass().getName());
            if (f == null) {
                return UNKNOWN_RET_CODE;  // should not be executed
            }
            return (Integer) f.get(object);
        } catch (Exception e) {
            log.error("getRetCode exception", e);
            return UNKNOWN_RET_CODE; // should not be called
        }
    }

    /**
     * Writes the retCode of a response object in place through its cached field.
     *
     * @param object  response object whose class was registered by one of the registration methods
     * @param retCode value to store
     * @throws RuntimeException if the class has no registered retCode field or the write fails
     */
    public static void setRetCode(Object object, int retCode) {
        String clsName = object.getClass().getName();
        Field f = retCodeFields.get(clsName);
        if (f == null) {
            throw new RuntimeException("setRetCode exception, no field found, cls=" + clsName);
        }
        try {
            f.set(object, retCode);
        } catch (Exception e) {
            throw new RuntimeException("setRetCode exception", e);
        }
    }

    /**
     * Reads the retMsg of a response message.
     *
     * @param object response message ({@link DynamicMessage} or a registered generated class)
     * @return the message text; {@code ""} when the class has no retMsg field or reading fails
     *         (failures of the reflective read are logged)
     */
    public static String getRetMsg(Object object) {
        if (object instanceof DynamicMessage) {
            DynamicMessage dm = (DynamicMessage) object;
            FieldDescriptor f = dm.getDescriptorForType().findFieldByName(retMsgFieldInMap);
            if (f == null)
                return "";
            return (String) dm.getField(f);
        }

        try {
            Field f = retMsgFields.get(object.getClass().getName());
            if (f == null) {
                return "";
            }

            // The generated field holds either a String or a not-yet-decoded ByteString.
            Object value = f.get(object);
            if (value instanceof String) {
                return (String) value;
            }
            return ((ByteString) value).toStringUtf8();
        } catch (Exception e) {
            log.error("getRetMsg exception", e);
            return "";
        }
    }

    /**
     * Writes the retMsg of a response object in place. A {@code null} message is stored as
     * {@code ""}. Does nothing when the class has no registered retMsg field.
     *
     * @throws RuntimeException if the reflective write fails
     */
    public static void setRetMsg(Object object, String retMsg) {
        try {
            Field f = retMsgFields.get(object.getClass().getName());
            if (f == null) {
                return;
            }
            f.set(object, retMsg == null ? "" : retMsg);
        } catch (Exception e) {
            throw new RuntimeException("setRetMsg exception", e);
        }
    }

    /**
     * Loads a class by name.
     *
     * @return the class, or {@code null} if {@code s} is null or the class cannot be loaded/linked
     */
    public static Class<?> getClass(String s) {
        if (s == null) {
            return null;
        }
        try {
            return Class.forName(s);
        } catch (ClassNotFoundException | LinkageError ignored) {
            // absence is reported to the caller as null
            return null;
        }
    }

    /**
     * Instantiates {@code cls} through its declared constructor taking a single {@link RpcCallable}.
     *
     * @throws RuntimeException wrapping any reflective failure
     */
    public static Object newRefererObject(Class<?> cls, Object callable) {
        try {
            Constructor<?> cons = cls.getDeclaredConstructor(callableTypes);
            return cons.newInstance(callable);
        } catch (Exception e) {
            throw new RuntimeException("newRefererObject exception", e);
        }
    }

    /**
     * Loads {@code clsName} and instantiates it through its declared no-arg constructor.
     *
     * @throws RuntimeException wrapping any loading, linking or instantiation failure
     */
    public static Object newObject(String clsName) {
        try {
            Class<?> cls = Class.forName(clsName);
            Constructor<?> cons = cls.getDeclaredConstructor();
            return cons.newInstance();
        } catch (Exception | LinkageError e) {
            throw new RuntimeException("newObject exception, cls=" + clsName, e);
        }
    }

    /**
     * Invokes a no-arg method declared directly on {@code obj}'s class.
     *
     * @return the method's result, or {@code null} if lookup or invocation fails (the failure,
     *         including the cause thrown by the target method, is logged)
     */
    public static Object invokeMethod(Object obj, String methodName) {
        try {
            Method method = obj.getClass().getDeclaredMethod(methodName, dummyTypes);
            return method.invoke(obj, dummyParameters);
        } catch (Exception e) {
            log.error("invokeMethod exception, method=" + methodName, e);
            return null;
        }
    }

    /**
     * Returns a cached response object of type {@code cls} carrying {@code retCode}/{@code retMsg},
     * creating it with {@link #generateResponseObjectNoCache(Class, int, String)} on first use.
     * The returned instance is shared between callers.
     */
    public static Object generateResponseObject(Class<?> cls, int retCode, String retMsg) {
        String key = cls.getName() + ":" + retCode + ":" + retMsg;
        Object cached = errors.get(key);
        if (cached != null) return cached;
        // computeIfAbsent ensures concurrent misses for the same key build and publish one object
        return errors.computeIfAbsent(key, k -> generateResponseObjectNoCache(cls, retCode, retMsg));
    }

    /**
     * Builds a response object of generated message type {@code cls}.
     *
     * <p>For {@code retCode == 0} the class's {@code getDefaultInstance()} is returned and
     * {@code retMsg} is ignored. Otherwise an empty message is built via {@code newBuilder().build()}
     * and its retCode (and retMsg, when non-empty) are set reflectively; this requires {@code cls}
     * to have been registered so that its retCode field is cached.
     *
     * @throws RuntimeException wrapping any reflective failure
     */
    public static Object generateResponseObjectNoCache(Class<?> cls, int retCode, String retMsg) {
        try {
            if (retCode == 0) {
                Method method = cls.getDeclaredMethod("getDefaultInstance", dummyTypes);
                return method.invoke(null, dummyParameters);
            }

            Method method = cls.getDeclaredMethod("newBuilder", dummyTypes);
            Object builder = method.invoke(null, dummyParameters);
            Method buildMethod = builder.getClass().getDeclaredMethod("build", dummyTypes);
            Object obj = buildMethod.invoke(builder, dummyParameters);

            setRetCode(obj, retCode);
            if (retMsg != null && retMsg.length() > 0) setRetMsg(obj, retMsg);
            return obj;
        } catch (Exception e) {
            throw new RuntimeException("generateResponseObjectNoCache exception", e);
        }
    }

    /**
     * Returns a fresh builder from {@code cls}'s static {@code newBuilder()} method.
     *
     * @throws RuntimeException wrapping any reflective failure
     */
    public static Builder generateBuilder(Class<?> cls) {
        try {
            Method method = cls.getDeclaredMethod("newBuilder", dummyTypes);
            return (Builder) method.invoke(null, dummyParameters);
        } catch (Exception e) {
            throw new RuntimeException("generateBuilder exception", e);
        }
    }

    /**
     * Appends the address part of {@code connId} (everything before its last ':', or the whole
     * id when it has no ':') to the comma-separated peers list of {@code meta}'s trace, in place.
     * Reflection failures are logged, not thrown.
     *
     * <p>NOTE(review): this mutates an already-built protobuf message; presumably only safe
     * before {@code meta} has been serialized or had its size memoized - confirm with callers.
     */
    public static void adjustPeers(RpcMeta meta, String connId) {
        int p = connId.lastIndexOf(":");
        String addr = p >= 0 ? connId.substring(0, p) : connId;

        // Identity check against the shared default Trace instance: mutating it in place would
        // append to the same shared object on every call and grow without bound (OOM), so a new
        // Trace is attached instead.
        if (meta.getTrace() == RpcMeta.Trace.getDefaultInstance()) {
            RpcMeta.Trace trace = RpcMeta.Trace.newBuilder().setPeers(addr).build();
            try {
                metaTraceField.set(meta, trace);
            } catch (Exception e) {
                log.error("adjustPeers exception", e);
            }
        } else {
            String peers = meta.getTrace().getPeers();
            String newPeers = peers.isEmpty() ? addr : peers + "," + addr;
            try {
                metaPeersField.set(meta.getTrace(), newPeers);
            } catch (Exception e) {
                log.error("adjustPeers exception", e);
            }
        }
    }

    /**
     * Sets traceId, parentSpanId and spanId on {@code meta}'s trace in place. When the trace is
     * the shared default instance, a new Trace is attached instead of mutating the shared one.
     * Reflection failures are logged, not thrown.
     */
    public static void adjustTrace(RpcMeta meta, String traceId, String parentSpanId, String spanId) {
        if (meta.getTrace() == RpcMeta.Trace.getDefaultInstance()) { // shared default instance
            RpcMeta.Trace trace = RpcMeta.Trace.newBuilder()
                    .setTraceId(traceId)
                    .setParentSpanId(parentSpanId)
                    .setSpanId(spanId)
                    .build();
            try {
                metaTraceField.set(meta, trace);
            } catch (Exception e) {
                log.error("adjustTrace exception", e);
            }
        } else {
            try {
                RpcMeta.Trace trace = meta.getTrace();
                metaTraceIdField.set(trace, traceId);
                metaParentSpanIdField.set(trace, parentSpanId);
                metaSpanIdField.set(trace, spanId);
            } catch (Exception e) {
                log.error("adjustTrace exception", e);
            }
        }
    }

    /** Overwrites {@code meta}'s sequence in place; reflection failures are logged. */
    public static void updateSequence(RpcMeta meta, int sequence) {
        try {
            metaSequenceField.set(meta, sequence);
        } catch (Exception e) {
            log.error("updateSequence exception", e);
        }
    }

    /** Overwrites {@code meta}'s compress value in place; reflection failures are logged. */
    public static void updateCompress(RpcMeta meta, int zip) {
        try {
            metaCompressField.set(meta, zip);
        } catch (Exception e) {
            log.error("updateCompress exception", e);
        }
    }

    /** Overwrites {@code meta}'s encrypt value in place; reflection failures are logged. */
    public static void updateEncrypt(RpcMeta meta, int encrypt) {
        try {
            metaEncryptField.set(meta, encrypt);
        } catch (Exception e) {
            log.error("updateEncrypt exception", e);
        }
    }

    /**
     * Verifies that {@code obj} implements {@code intf}.
     *
     * @throws RuntimeException if it does not
     */
    public static void checkInterface(Class<?> intf, Object obj) {
        if (intf.isAssignableFrom(obj.getClass())) return;
        throw new RuntimeException("not a valid service object");
    }

    /**
     * Loads the interface named {@code intfName} and verifies that {@code obj} implements it.
     *
     * @throws RuntimeException "interface not found" if the interface cannot be loaded, or
     *                          "not a valid service object" if {@code obj} does not implement it
     */
    public static void checkInterface(String intfName, Object obj) {
        Class<?> intf;
        try {
            intf = Class.forName(intfName);
        } catch (ClassNotFoundException | LinkageError e) {
            throw new RuntimeException("interface not found, cls=" + intfName, e);
        }
        checkInterface(intf, obj);
    }

    /**
     * Reads the static {@code serviceId} constant declared on a service interface.
     *
     * @throws RuntimeException if the field is missing, not static, or unreadable
     */
    public static int getServiceId(Class<?> intf) {
        try {
            Field field = intf.getDeclaredField("serviceId");
            if (Modifier.isStatic(field.getModifiers())) {
                return (Integer) field.get(null);
            }
        } catch (Exception e) {
            throw new RuntimeException("interface_parse_serviceId_exception", e);
        }
        // a non-static serviceId member is not a service id constant
        throw new RuntimeException("interface_parse_serviceId_exception");
    }

    /**
     * Collects the static {@code <name>MsgId} int constants of a service interface.
     *
     * @return map from message id to {@code <name>} (the field name without the suffix)
     * @throws RuntimeException if a matching field cannot be read as an int
     */
    public static HashMap<Integer, String> getMsgIds(Class<?> intf) {
        try {
            HashMap<Integer, String> msgIds = new HashMap<Integer, String>();
            for (Field field : intf.getDeclaredFields()) {
                if (!Modifier.isStatic(field.getModifiers())) continue;
                String fieldName = field.getName();
                if (!fieldName.endsWith(MSG_ID_SUFFIX)) continue;
                int msgId = field.getInt(null);
                msgIds.put(msgId, fieldName.substring(0, fieldName.length() - MSG_ID_SUFFIX.length()));
            }
            return msgIds;
        } catch (Exception e) {
            throw new RuntimeException("interface_parse_msgId_exception", e);
        }
    }

    /**
     * Collects the static {@code <name>MsgId} int constants of a service interface.
     *
     * @return map from {@code <name>} (the field name without the suffix) to message id
     * @throws RuntimeException if a matching field cannot be read as an int
     */
    public static HashMap<String, Integer> getMsgNames(Class<?> intf) {
        try {
            HashMap<String, Integer> msgNames = new HashMap<String, Integer>();
            for (Field field : intf.getDeclaredFields()) {
                if (!Modifier.isStatic(field.getModifiers())) continue;
                String fieldName = field.getName();
                if (!fieldName.endsWith(MSG_ID_SUFFIX)) continue;
                int msgId = field.getInt(null);
                msgNames.put(fieldName.substring(0, fieldName.length() - MSG_ID_SUFFIX.length()), msgId);
            }
            return msgNames;
        } catch (Exception e) {
            throw new RuntimeException("interface_parse_msgId_exception", e);
        }
    }

    /**
     * Rejects service interfaces where a non-static, single-argument method takes and returns
     * the same {@link Message} class.
     *
     * @throws RuntimeException naming the offending method
     */
    public static void checkReqResSame(Class<?> intf) {
        for (Method m : intf.getDeclaredMethods()) {
            if (Modifier.isStatic(m.getModifiers())) continue;
            if (m.getParameterCount() != 1) continue;
            Class<?> reqCls = m.getParameterTypes()[0];
            Class<?> resCls = m.getReturnType();
            if (!Message.class.isAssignableFrom(reqCls)) continue;
            if (!Message.class.isAssignableFrom(resCls)) continue;

            if (reqCls == resCls) {
                throw new RuntimeException("method define error in proto file, method=" + m.getName() + ", res cls == req cls");
            }
        }
    }

    /**
     * Describes the synchronous RPC methods of a service interface: non-static methods taking one
     * {@link Message} and returning a {@link Message}.
     *
     * <p>For each such method {@code m} the map holds {@code m} (the {@link Method}),
     * {@code m-req}/{@code m-res} (the classes) and {@code m-reqp}/{@code m-resp} (their static
     * {@code parseFrom(InputStream)} methods). As a side effect the response class's
     * retCode/retMsg fields are registered for {@link #getRetCode}/{@link #setRetCode} etc.
     *
     * @throws RuntimeException if a parser or the response's retCode field is missing
     */
    public static HashMap<String, Object> getMethodInfo(Class<?> intf) {
        try {
            HashMap<String, Object> info = new HashMap<String, Object>();
            for (Method m : intf.getDeclaredMethods()) {
                if (Modifier.isStatic(m.getModifiers())) continue;
                if (m.getParameterCount() != 1) continue;
                Class<?> reqCls = m.getParameterTypes()[0];
                Class<?> resCls = m.getReturnType();
                if (!Message.class.isAssignableFrom(reqCls)) continue;
                if (!Message.class.isAssignableFrom(resCls)) continue;
                putMethodInfo(info, m, reqCls, resCls);
            }
            return info;
        } catch (Exception e) {
            throw new RuntimeException("getMethodInfo intf=" + intf.getName(), e);
        }
    }

    /**
     * Returns the static {@code parseFrom(InputStream)} methods of a request/response pair under
     * keys {@code reqp}/{@code resp}, and registers the response class's retCode/retMsg fields.
     *
     * @throws RuntimeException if a parser or the response's retCode field is missing
     */
    public static HashMap<String, Method> getParsers(Class<?> reqCls, Class<?> resCls) {
        try {
            HashMap<String, Method> map = new HashMap<String, Method>();
            map.put("reqp", reqCls.getDeclaredMethod("parseFrom", InputStream.class));
            map.put("resp", resCls.getDeclaredMethod("parseFrom", InputStream.class));
            registerRetFields(resCls);
            return map;
        } catch (Exception e) {
            throw new RuntimeException("getParsers reqCls=" + reqCls + ",resCls=" + resCls, e);
        }
    }

    /**
     * Async counterpart of {@link #getMethodInfo}: describes non-static methods taking one
     * {@link Message} and returning {@code CompletableFuture<R>} where {@code R} is a
     * {@link Message}. Keys and side effects are the same as in {@link #getMethodInfo}, with
     * {@code m-res} being {@code R}.
     *
     * @throws RuntimeException if {@code R} cannot be resolved to a class, or a parser or the
     *                          response's retCode field is missing
     */
    public static HashMap<String, Object> getAsyncMethodInfo(Class<?> intf) {
        try {
            HashMap<String, Object> info = new HashMap<String, Object>();
            for (Method m : intf.getDeclaredMethods()) {
                if (Modifier.isStatic(m.getModifiers())) continue;
                if (m.getParameterCount() != 1) continue;
                Class<?> reqCls = m.getParameterTypes()[0];
                if (!Message.class.isAssignableFrom(reqCls)) continue;
                if (!CompletableFuture.class.isAssignableFrom(m.getReturnType())) continue;
                Class<?> resCls = parseParameterCls(m.getGenericReturnType());
                if (!Message.class.isAssignableFrom(resCls)) continue;
                putMethodInfo(info, m, reqCls, resCls);
            }
            return info;
        } catch (Exception e) {
            throw new RuntimeException("getAsyncMethodInfo intf=" + intf.getName(), e);
        }
    }

    /**
     * Records {@code m}, its request/response classes and their {@code parseFrom(InputStream)}
     * parsers under {@code <name>}, {@code <name>-req}, {@code <name>-res}, {@code <name>-reqp},
     * {@code <name>-resp}, then registers the response class's retCode/retMsg fields.
     */
    private static void putMethodInfo(Map<String, Object> info, Method m, Class<?> reqCls, Class<?> resCls)
            throws NoSuchMethodException, NoSuchFieldException {
        String name = m.getName();
        info.put(name, m);
        info.put(name + "-req", reqCls);
        info.put(name + "-res", resCls);
        info.put(name + "-reqp", reqCls.getDeclaredMethod("parseFrom", InputStream.class));
        info.put(name + "-resp", resCls.getDeclaredMethod("parseFrom", InputStream.class));
        registerRetFields(resCls);
    }

    /**
     * Caches the accessible retCode field (required) and retMsg field (optional) of a generated
     * response class, keyed by class name.
     *
     * @throws NoSuchFieldException if the retCode field does not exist
     */
    private static void registerRetFields(Class<?> resCls) throws NoSuchFieldException {
        Field codeField = resCls.getDeclaredField(retCodeField);
        codeField.setAccessible(true);
        retCodeFields.put(resCls.getName(), codeField);
        try {
            Field msgField = resCls.getDeclaredField(retMsgField);
            msgField.setAccessible(true);
            retMsgFields.put(resCls.getName(), msgField);
        } catch (Exception ignored) {
            // retMsg is optional: without an entry getRetMsg returns "" and setRetMsg is a no-op
        }
    }

    /**
     * Resolves the single type argument of a parameterized type such as
     * {@code CompletableFuture<Res>} to its class.
     *
     * @throws ClassNotFoundException if {@code t} is not parameterized or its argument is not a
     *                                plain class (e.g. a wildcard or type variable)
     */
    private static Class<?> parseParameterCls(Type t) throws ClassNotFoundException {
        if (t instanceof ParameterizedType) {
            Type[] args = ((ParameterizedType) t).getActualTypeArguments();
            if (args.length == 1 && args[0] instanceof Class) {
                return (Class<?>) args[0];
            }
        }
        throw new ClassNotFoundException("cannot resolve type parameter of " + t);
    }

    /** Returns a new {@link DynamicMessage} builder for {@code desc}. */
    public static DynamicMessage.Builder generateDynamicBuilder(Descriptor desc) {
        return DynamicMessage.newBuilder(desc);
    }

    /**
     * Returns a cached dynamic response carrying {@code retCode}/{@code retMsg}, keyed by
     * {@code cacheName}, creating it from {@code b} on first use. The returned instance is shared.
     */
    public static Object generateResponseObject(DynamicMessage.Builder b, String cacheName, int retCode, String retMsg) {
        String key = cacheName + ":" + retCode + ":" + retMsg;
        Object cached = errors.get(key);
        if (cached != null) return cached;
        // computeIfAbsent ensures concurrent misses for the same key build and publish one object
        return errors.computeIfAbsent(key, k -> generateResponseObjectNoCache(b, retCode, retMsg));
    }

    /**
     * Sets the fields named {@link #retCodeFieldInMap}/{@link #retMsgFieldInMap} on {@code b}
     * (when the descriptor declares them) and builds the message. A {@code null} retMsg is stored
     * as {@code ""}, matching {@link #setRetMsg}. Note that {@code b} itself is modified.
     *
     * @throws RuntimeException wrapping any failure to set a field or build
     */
    public static Object generateResponseObjectNoCache(DynamicMessage.Builder b, int retCode, String retMsg) {
        try {
            Descriptor desc = b.getDescriptorForType();
            FieldDescriptor codeField = desc.findFieldByName(retCodeFieldInMap);
            if (codeField != null)
                b.setField(codeField, retCode);
            FieldDescriptor msgField = desc.findFieldByName(retMsgFieldInMap);
            if (msgField != null)
                b.setField(msgField, retMsg == null ? "" : retMsg);
            return b.build();
        } catch (Exception e) {
            throw new RuntimeException("generateResponseObjectNoCache exception", e);
        }
    }

}