package com.github.chenlei2.springboot.mybatis.rw.starter.pulgin; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Proxy; import java.sql.Connection; import java.util.Properties; import com.github.chenlei2.springboot.mybatis.rw.starter.datasource.ConnectionHold; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.ibatis.executor.statement.RoutingStatementHandler; import org.apache.ibatis.executor.statement.StatementHandler; import org.apache.ibatis.logging.jdbc.ConnectionLogger; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Plugin; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.reflection.DefaultReflectorFactory; import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.factory.DefaultObjectFactory; import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory; import org.springframework.jdbc.datasource.ConnectionProxy; /** * 数据源读写分离路由 * * @author chenlei * */ @Intercepts({ @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}) }) public class RWPlugin implements Interceptor { public static final Log LOG = LogFactory.getLog(RWPlugin.class); public Object intercept(Invocation invocation) throws Throwable { Connection conn = (Connection) invocation.getArgs()[0]; conn = unwrapConnection(conn); if (conn instanceof ConnectionProxy) { //强制走写库 if(ConnectionHold.FORCE_WRITE.get() != null && ConnectionHold.FORCE_WRITE.get()){ if (LOG.isDebugEnabled()) { LOG.debug("本事务强制走写库, 数据库url"+conn.getMetaData().getURL()+", 数据库name" + conn.getMetaData().getDatabaseProductName()); } routeConnection(ConnectionHold.WRITE, conn); return invocation.proceed(); } StatementHandler statementHandler = (StatementHandler) invocation.getTarget(); MetaObject metaObject = MetaObject.forObject(statementHandler, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(), new DefaultReflectorFactory()); MappedStatement mappedStatement = null; if (statementHandler instanceof RoutingStatementHandler) { mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); } else { mappedStatement = (MappedStatement) metaObject.getValue("mappedStatement"); } String key = ConnectionHold.WRITE; String sel = statementHandler.getBoundSql().getSql().trim().substring(0,3); if (sel.equalsIgnoreCase("sel") && !mappedStatement.getId().endsWith(".insert!selectKey")) { key = ConnectionHold.READ; if (LOG.isDebugEnabled()) { LOG.debug("当前数据库为读库, 数据库url"+conn.getMetaData().getURL()+", 数据库name" + conn.getMetaData().getDatabaseProductName()); } } else { if (LOG.isDebugEnabled()) { LOG.debug("当前数据库为写库, 数据库url"+conn.getMetaData().getURL()+", 数据库name" + conn.getMetaData().getDatabaseProductName()); } } routeConnection(key, conn); } return invocation.proceed(); } private void routeConnection(String key, Connection conn) { ConnectionHold.CURRENT_CONNECTION.set(key); // 同一个线程下保证最多只有一个写数据链接和读数据链接 if (!ConnectionHold.CONNECTION_CONTEXT.get().containsKey(key)) { ConnectionProxy conToUse = (ConnectionProxy) conn; conn = conToUse.getTargetConnection(); ConnectionHold.CONNECTION_CONTEXT.get().put(key, conn); } } public Object plugin(Object target) { if (target instanceof StatementHandler) { return Plugin.wrap(target, this); } else { return target; } } public void setProperties(Properties properties) { // NOOP } /** * MyBatis wraps the JDBC Connection with a logging proxy but Spring registers the original connection so it should * be unwrapped before calling {@code DataSourceUtils.isConnectionTransactional(Connection, DataSource)} * * @param connection May be a {@code ConnectionLogger} proxy * @return the original JDBC {@code Connection} */ private Connection unwrapConnection(Connection connection) { if (Proxy.isProxyClass(connection.getClass())) { InvocationHandler handler = Proxy.getInvocationHandler(connection); if (handler instanceof ConnectionLogger) { return ((ConnectionLogger) handler).getConnection(); } } return connection; } }