package com.cyfonly.thriftj.loadbalance;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import javassist.util.proxy.MethodFilter;
import javassist.util.proxy.MethodHandler;
import javassist.util.proxy.Proxy;
import javassist.util.proxy.ProxyFactory;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.pool2.impl.GenericKeyedObjectPoolConfig;
import org.apache.thrift.TServiceClient;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;

import com.cyfonly.thriftj.constants.Constant;
import com.cyfonly.thriftj.exceptions.NoServerAvailableException;
import com.cyfonly.thriftj.exceptions.ValidationException;
import com.cyfonly.thriftj.failover.ConnectionValidator;
import com.cyfonly.thriftj.failover.FailoverChecker;
import com.cyfonly.thriftj.failover.FailoverStrategy;
import com.cyfonly.thriftj.pool.DefaultThriftConnectionPool;
import com.cyfonly.thriftj.pool.ThriftConnectionFactory;
import com.cyfonly.thriftj.pool.ThriftServer;
import com.cyfonly.thriftj.utils.MurmurHash3;
import com.cyfonly.thriftj.utils.ThriftClientUtil;
import com.google.common.base.Charsets;


/**
 * 基于 load balance 的 Client 选择器
 * @author yunfeng.cheng
 * @create 2016-11-21
 */
public class ClientSelector {
	
	private FailoverChecker failoverChecker;
    private DefaultThriftConnectionPool poolProvider;
    private int loadBalance;
    
    private AtomicInteger i = new AtomicInteger(0);
    
    @SuppressWarnings({ "rawtypes", "unchecked" })
	public ClientSelector(String servers, int loadBalance, ConnectionValidator validator, GenericKeyedObjectPoolConfig poolConfig, FailoverStrategy strategy, int connTimeout, String backupServers, int serviceLevel) {
    	this.failoverChecker = new FailoverChecker(validator, strategy, serviceLevel);
    	this.poolProvider = new DefaultThriftConnectionPool(new ThriftConnectionFactory(failoverChecker, connTimeout), poolConfig);
    	failoverChecker.setConnectionPool(poolProvider);
    	failoverChecker.setServerList(ThriftServer.parse(servers));
        if (StringUtils.isNotEmpty(backupServers)) {
        	failoverChecker.setBackupServerList(ThriftServer.parse(backupServers));
        } else{
        	failoverChecker.setBackupServerList(new ArrayList<ThriftServer>());
        }
        failoverChecker.startChecking();
    }
	
	public <X extends TServiceClient> X iface(Class<X> ifaceClass) {
		if (this.loadBalance == Constant.LoadBalance.HASH) {
			throw new ValidationException("Can not use HASH without a key.");
		}
		
		switch (this.loadBalance) {
		case Constant.LoadBalance.RANDOM:
			return getRandomClient(ifaceClass);
		case Constant.LoadBalance.ROUND_ROBIN:
			return getRRClient(ifaceClass);
		case Constant.LoadBalance.WEIGHT:
			return getWeightClient(ifaceClass);
		default:
			return getRandomClient(ifaceClass);
		}
    }
	
	public <X extends TServiceClient> X iface(Class<X> ifaceClass, String key) {
		if (this.loadBalance != Constant.LoadBalance.HASH) {
			throw new ValidationException("Must use other load balance strategy.");
		}
		
		return getHashIface(ifaceClass, key);
    }
	
	protected <X extends TServiceClient> X getRandomClient(Class<X> ifaceClass) {
		return iface(ifaceClass, ThriftClientUtil.randomNextInt());
	}
	
	protected <X extends TServiceClient> X getRRClient(Class<X> ifaceClass) {
		return iface(ifaceClass, i.getAndDecrement());
	}
	
	protected <X extends TServiceClient> X getWeightClient(Class<X> ifaceClass) {
		List<ThriftServer> servers = getAvaliableServers();
		if (servers == null || servers.isEmpty()) {
			throw new NoServerAvailableException("No server available.");
		}
		int[] weights = new int[servers.size()];
		for (int i = 0; i < servers.size(); i++) {
			weights[i] = servers.get(i).getWeight();
		}
		return iface(ifaceClass, servers.get(ThriftClientUtil.chooseWithWeight(weights)));
	}
	
	protected <X extends TServiceClient> X getHashIface(Class<X> ifaceClass, String key) {
		byte[] bytes = key.getBytes(Charsets.UTF_8);
        return iface(ifaceClass, MurmurHash3.murmurhash3_x86_32(bytes, 0, bytes.length, 0x1234ABCD));
	}
	
	protected <X extends TServiceClient> X iface(Class<X> ifaceClass, int index) {
        List<ThriftServer> serverList = getAvaliableServers();
        if (serverList == null || serverList.isEmpty()) {
            throw new NoServerAvailableException("No server available.");
        }
        index = Math.abs(index);
        final ThriftServer selected = serverList.get(index % serverList.size());
        return iface(ifaceClass, selected);
    }
	
	@SuppressWarnings("unchecked")
	protected <X extends TServiceClient> X iface(final Class<X> ifaceClass, final ThriftServer selected) {
        final TTransport transport;
        try {
            transport = poolProvider.getConnection(selected);
        } catch (RuntimeException e) {
            if (e.getCause() != null && e.getCause() instanceof TTransportException) {
            	failoverChecker.getFailoverStrategy().fail(selected);
            }
            throw e;
        }
        TProtocol protocol = new TBinaryProtocol(transport);

        ProxyFactory factory = new ProxyFactory();
        factory.setSuperclass(ifaceClass);
        factory.setFilter(new MethodFilter() {
            @Override
            public boolean isHandled(Method m) {
                return ThriftClientUtil.getInterfaceMethodNames(ifaceClass).contains(m.getName());
            }
        });
        try {
            X x = (X) factory.create(new Class[]{TProtocol.class}, new Object[]{protocol});
            ((Proxy) x).setHandler(new MethodHandler() {
                @Override
                public Object invoke(Object self, Method thisMethod, Method proceed, Object[] args) throws Throwable {
                    boolean success = false;
                    try {
                        Object result = proceed.invoke(self, args);
                        success = true;
                        return result;
                    } finally {
                        if (success) {
                            poolProvider.returnConnection(selected, transport);
                        } else {
                            failoverChecker.getFailoverStrategy().fail(selected);
                            poolProvider.returnBrokenConnection(selected, transport);
                        }
                    }
                }
            });
            return x;
        } catch (Exception e) {
            throw new RuntimeException("Fail to create proxy.", e);
        }
    }
	
	public List<ThriftServer> getAvaliableServers() {
        return failoverChecker.getAvailableServers();
    }
	
	public void close(){
		failoverChecker.stopChecking();
        poolProvider.close();
	}
}