/** * */ package io.client.thrift; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.net.Socket; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import javax.net.SocketFactory; /** * @author HouKangxi * */ public class ClientInterfaceFactory { private ClientInterfaceFactory() { } private static ConcurrentHashMap<Long, Object> ifaceCache = new ConcurrentHashMap<Long, Object>(); /** * 获得与服务端通信的接口对象 * <p> * 调用者可以实现自定义的 * SocketFactory来内部配置Socket参数(如超时时间,SSL等),也可以通过返回包装的Socket来实现连接池, * 也可以使用内置的连接池类:{@link io.client.thrift.pool.SocketConnectionPool} <br/> * 使用例子:<br/> * {@code SocketFactory tcpfac = new TcpSocketFactory("localhost", 8080);} * <br/> * {@code SocketFactory pool = new SocketConnectionPool(tcpfac);} <br/> * {@code SomeIface service = ClientInterfaceFactory.getClientInterface(SomeIface.class, pool); } * <br/> * * * @param ifaceClass * - 接口class * @param factory * - 套接字工厂类, 注意:需要实现 createSocket() 方法,需要实现hashCode()方法来区分factory * @return 接口对象 */ @SuppressWarnings("unchecked") public static <INTERFACE> INTERFACE getClientInterface(Class<INTERFACE> ifaceClass, SocketFactory factory) { long part1 = ifaceClass.getName().hashCode(); final Long KEY = (part1 << 32) | factory.hashCode(); INTERFACE iface = (INTERFACE) ifaceCache.get(KEY); if (iface == null) { iface = (INTERFACE) Proxy.newProxyInstance(ifaceClass.getClassLoader(), new Class[] { ifaceClass }, new Handler(factory)); ifaceCache.putIfAbsent(KEY, iface); } return iface; } private static class Handler implements InvocationHandler { final AtomicInteger seqIdHolder = new AtomicInteger(0); final SocketFactory factory; public Handler(SocketFactory factory) { this.factory = factory; } public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { String methodName = method.getName(); if (args == null || args.length == 0) { if (methodName.equals("toString")) { return Handler.class.getName() + "@" + System.identityHashCode(this); } if (methodName.equals("hashCode")) { return System.identityHashCode(this); } } int seqId = seqIdHolder.incrementAndGet(); ByteArrayOutputStream outbuff = new ByteArrayOutputStream(); TCompactProtocol protocol = new TCompactProtocol(outbuff, null); ProtocolIOUtil.write(methodName, seqId, protocol, method.getGenericParameterTypes(), args); Socket connection = null; Object rs = null; boolean success = true; try { byte[] frame; { byte[] arrContent = outbuff.toByteArray(); final int msgLen = arrContent.length; // System.out.printf("*** 客户端 msgLen = %d, time=%d, // connection = %s\n", msgLen, System.currentTimeMillis(), // connection); frame = new byte[4 + msgLen];// 前四个字节代表消息长度 frame[0] = (byte) (msgLen >> 24); frame[1] = (byte) ((msgLen >> 16) & 0xff); frame[2] = (byte) ((msgLen >> 8) & 0xff); frame[3] = (byte) (msgLen & 0xff); // System.out.printf("** arrayLen = [%d, %d, %d, %d]\n", // arr4Req[0], arr4Req[1], arr4Req[2], arr4Req[3]); System.arraycopy(arrContent, 0, frame, 4, msgLen); } connection = factory.createSocket(); OutputStream out = connection.getOutputStream(); out.write(frame); out.flush(); InputStream in = connection.getInputStream(); if (in != null) { // int readLen = 0, offset = 0; // while (readLen < 4) { // readLen += in.read(arrLen, offset, 4 - readLen); // } int readLen = in.read(frame, 0, 4); if (readLen == 1) { readLen = in.read(frame, 1, 4); // System.out.printf("** respArrayLen(!1) = [%d, %d, %d, // %d]\n", arr4Req[1], arr4Req[2], // arr4Req[3], arr4Req[4]); } /* * else if (readLen == 4) { System.out.printf( * "** respArrayLen = [%d, %d, %d, %d]\n", arr4Req[0], * arr4Req[1], arr4Req[2], arr4Req[3]); } */ // System.out.println("readLen=" + readLen + ",connection = // " + connection); if (readLen == 4) { // 此时arrLen代表返回结果的长度 protocol.transIn = in; rs = ProtocolIOUtil.read(protocol, method.getGenericReturnType(), method.getExceptionTypes(), seqId); } /* * else { System.out.println("arr[0]=" + arr4Req[0] + * ", 出错的socket: " + connection); } */ } } catch (IOException ex) { success = false; throw ex; } catch (Throwable ex) { success = false; throw ex; } finally { if (connection != null) { if (success) { // 正常情况,通过socket.close()关闭,方便切换到定制业务 connection.close(); } else { // 异常情况,直接通过IO流关闭 try { connection.getOutputStream().close(); connection.getInputStream().close(); } catch (Throwable e) { } } } } return rs; } } }