package com.dianping.zebra.shard.jdbc; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; import java.util.Map; import javax.sql.DataSource; import com.alibaba.druid.sql.ast.SQLObjectImpl; import org.junit.Test; import com.dianping.zebra.shard.jdbc.base.MultiDBBaseTestCase; import com.dianping.zebra.shard.parser.SQLParsedResult; import com.dianping.zebra.shard.parser.SQLParser; import com.dianping.zebra.shard.parser.ShardLimitSqlSplitRewrite; import com.dianping.zebra.shard.parser.SqlToCountSqlRewrite; import junit.framework.Assert; public class MultiDBPreparedStatementLifeCycleLimitTest extends MultiDBBaseTestCase { @Override protected String getDBBaseUrl() { return "jdbc:h2:mem:"; } @Override protected String getCreateScriptConfigFile() { return "db-datafiles/createtable-multidb-lifecycle.xml"; } @Override protected String getDataFile() { return "db-datafiles/data-limitdb-lifecycle.xml"; } @Override protected String[] getSpringConfigLocations() { return new String[] { "ctx-multidb-lifecycle.xml" }; } @Test public void testPopResult() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn.prepareStatement("select * from test limit 1"); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<TestEntity> popResult = popResult(rs); Assert.assertEquals(1, popResult.size()); } catch (Exception e) { e.printStackTrace(); Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test // 测试改写后offset为0 public void testMultiRouterLimitResult0() throws Exception { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn.prepareStatement("select score from test order by score limit 5,4"); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<Integer> rows = new ArrayList<Integer>(); while (rs.next()) { rows.add(rs.getInt("score")); } Assert.assertEquals(4, rows.size()); Assert.assertEquals(3, rows.get(0).intValue()); Assert.assertEquals(3, rows.get(1).intValue()); Assert.assertEquals(3, rows.get(2).intValue()); Assert.assertEquals(3, rows.get(3).intValue()); } catch (Exception e) { e.printStackTrace(); Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test // 测试倒序排改写后offset为0 public void testMultiRouterLimitResult1() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn.prepareStatement("select score from test order by score desc limit 7,3"); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<Integer> rows = new ArrayList<Integer>(); while (rs.next()) { rows.add(rs.getInt("score")); } Assert.assertEquals(3, rows.size()); Assert.assertEquals(9, rows.get(0).intValue()); Assert.assertEquals(9, rows.get(1).intValue()); Assert.assertEquals(9, rows.get(2).intValue()); } catch (Exception e) { System.err.println(e); Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test // 测试改写后offset不为0 public void testMultiRouterLimitResult2() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn.prepareStatement("SELECT * FROM test ORDER BY score LIMIT 10,3"); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<Integer> rows = new ArrayList<Integer>(); while (rs.next()) { rows.add(rs.getInt("score")); } Assert.assertEquals(3, rows.size()); Assert.assertEquals(3, rows.get(0).intValue()); Assert.assertEquals(4, rows.get(1).intValue()); Assert.assertEquals(5, rows.get(2).intValue()); } catch (Exception e) { System.err.println(e); Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test // 测试正排,倒排且部分表第一次结果为0的case public void testMultiRouterLimitResult3() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn .prepareStatement("SELECT * FROM test WHERE score > 6 ORDER BY score DESC, NAME ASC LIMIT 5,3"); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<TestEntity> popResult = popResult(rs); Assert.assertEquals(3, popResult.size()); Assert.assertEquals("conan2", popResult.get(0).getName()); Assert.assertEquals(10, popResult.get(0).getScore()); Assert.assertEquals("conan4", popResult.get(1).getName()); Assert.assertEquals(10, popResult.get(1).getScore()); Assert.assertEquals("conan1", popResult.get(2).getName()); Assert.assertEquals(9, popResult.get(2).getScore()); } catch (Exception e) { System.err.println(e); Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test // 测试正排,倒排 public void testMultiRouterLimitResult4() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn .prepareStatement("SELECT * FROM test WHERE score <10 ORDER BY score DESC, NAME ASC LIMIT 5,3"); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<TestEntity> popResult = popResult(rs); Assert.assertEquals(3, popResult.size()); Assert.assertEquals("conan2", popResult.get(0).getName()); Assert.assertEquals(8, popResult.get(0).getScore()); Assert.assertEquals("conan3", popResult.get(1).getName()); Assert.assertEquals(8, popResult.get(1).getScore()); Assert.assertEquals("conan0", popResult.get(2).getName()); Assert.assertEquals(7, popResult.get(2).getScore()); } catch (Exception e) { System.err.println(e); Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test public void testSmallLimit() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn.prepareStatement("select score from test order by score asc limit 3,4"); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<TestEntity> popResult = popResult(rs); Assert.assertEquals(4, popResult.size()); Assert.assertEquals(2, popResult.get(0).getScore()); Assert.assertEquals(2, popResult.get(1).getScore()); Assert.assertEquals(3, popResult.get(2).getScore()); Assert.assertEquals(3, popResult.get(2).getScore()); } catch (Exception e) { Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test public void testLimitSqlWithParams() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn .prepareStatement("select score from test where score > ? order by score asc limit ?,?"); stmt.setInt(1, 8); stmt.setInt(2, 1); stmt.setInt(3, 3); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<TestEntity> popResult = popResult(rs); Assert.assertEquals(3, popResult.size()); Assert.assertEquals(9, popResult.get(0).getScore()); Assert.assertEquals(9, popResult.get(1).getScore()); Assert.assertEquals(10, popResult.get(2).getScore()); } catch (Exception e) { Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test public void testLimitSqlWithLargeParams() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn.prepareStatement("select id, score from test order by score asc limit ?,?"); stmt.setInt(1, 10); stmt.setInt(2, 3); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<TestEntity> popResult = popResult(rs); Assert.assertEquals(3, popResult.size()); Assert.assertEquals(3, popResult.get(0).getScore()); Assert.assertEquals(4, popResult.get(1).getScore()); Assert.assertEquals(5, popResult.get(2).getScore()); } catch (Exception e) { Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test public void testSingleRuleLimit() throws SQLException { DataSource ds = (DataSource) context.getBean("zebraDS"); Connection conn = null; try { conn = ds.getConnection(); PreparedStatement stmt = conn.prepareStatement("select * from test where id = 0 order by id limit ?, ?"); stmt.setInt(1, 1); stmt.setInt(2, 3); stmt.execute(); ResultSet rs = stmt.getResultSet(); List<TestEntity> popResult = popResult(rs); Assert.assertEquals(0, popResult.size()); } catch (Exception e) { Assert.fail(); } finally { if (conn != null) { conn.close(); } } } @Test public void testSplitLimitSqlReWriteFunction() { String sql = "select * from test limit 10,3"; SQLParsedResult parseWithoutCache = SQLParser.parseWithoutCache(sql); String rewriteSql = new ShardLimitSqlSplitRewrite().rewrite(parseWithoutCache, 8, null); Assert.assertEquals("SELECT *\nFROM test\nLIMIT 1, 3", rewriteSql); } @Test public void testSplitLimitSqlReWriteFunction2() { String sql = "select * from test limit 3 offset 10"; SQLParsedResult parseWithoutCache = SQLParser.parseWithoutCache(sql); String rewriteSql = new ShardLimitSqlSplitRewrite().rewrite(parseWithoutCache, 8, null); Assert.assertEquals("SELECT *\nFROM test\nLIMIT 1, 3", rewriteSql); } @Test public void testCount() { String sql = "select count(test.score) from test order by test.score"; SQLParsedResult parseWithoutCache = SQLParser.parseWithoutCache(sql); Map<String, SQLObjectImpl> selectItemMap = parseWithoutCache.getMergeContext().getSelectItemMap(); Map<String, String> columnNameAliasMapping = parseWithoutCache.getMergeContext().getColumnNameAliasMapping(); Assert.assertEquals(true, selectItemMap.containsKey("COUNT(test.score)")); Assert.assertEquals(true, columnNameAliasMapping.isEmpty()); } @Test public void testSelectAll() { String sql = "select test.* from test order by test.score"; SQLParsedResult parseWithoutCache = SQLParser.parseWithoutCache(sql); Map<String, SQLObjectImpl> selectItemMap = parseWithoutCache.getMergeContext().getSelectItemMap(); Map<String, String> columnNameAliasMapping = parseWithoutCache.getMergeContext().getColumnNameAliasMapping(); Assert.assertEquals("test.*", selectItemMap.get("*").toString()); Assert.assertEquals(true, columnNameAliasMapping.isEmpty()); } @Test public void testOrderbyFullName() { String sql = "select test.name as n, test.score as s from test order by test.score"; SQLParsedResult parseWithoutCache = SQLParser.parseWithoutCache(sql); Map<String, SQLObjectImpl> selectItemMap = parseWithoutCache.getMergeContext().getSelectItemMap(); Map<String, String> columnNameAliasMapping = parseWithoutCache.getMergeContext().getColumnNameAliasMapping(); Assert.assertEquals(true, selectItemMap.containsKey("n")); Assert.assertEquals(true, selectItemMap.containsKey("s")); Assert.assertEquals("n", columnNameAliasMapping.get("name")); Assert.assertEquals("s", columnNameAliasMapping.get("score")); } @Test public void testOrderbyAlias() { String sql = "select test.name as n, test.score as s from test order by s"; SQLParsedResult parseWithoutCache = SQLParser.parseWithoutCache(sql); Map<String, SQLObjectImpl> selectItemMap = parseWithoutCache.getMergeContext().getSelectItemMap(); Map<String, String> columnNameAliasMapping = parseWithoutCache.getMergeContext().getColumnNameAliasMapping(); Assert.assertEquals(true, selectItemMap.containsKey("n")); Assert.assertEquals(true, selectItemMap.containsKey("s")); Assert.assertEquals("n", columnNameAliasMapping.get("name")); Assert.assertEquals("s", columnNameAliasMapping.get("score")); } @Test public void testReWriteFunction() { String sql = "select * from test"; String rewrite = new SqlToCountSqlRewrite().rewrite(sql, null); Assert.assertEquals(rewrite, "SELECT COUNT(*) AS zebra_count\nFROM test"); } @SuppressWarnings("unused") private class TestEntity { private int id; private String name; private int score; private String type; private int classid; public int getId() { return id; } public void setId(int id) { this.id = id; } public String getName() { return name; } public void setName(String name) { this.name = name; } public int getScore() { return score; } public void setScore(int score) { this.score = score; } public String getType() { return type; } public void setType(String type) { this.type = type; } public int getClassid() { return classid; } public void setClassid(int classid) { this.classid = classid; } } public List<TestEntity> popResult(ResultSet rs) throws SQLException, IOException { ArrayList<TestEntity> result = new ArrayList<TestEntity>(); while (rs.next()) { TestEntity entity = new TestEntity(); ResultSetMetaData metaData = rs.getMetaData(); for (int i = metaData.getColumnCount(); i > 0; i--) { String columnName = metaData.getColumnName(i).toLowerCase(); Class<?>[] params = new Class<?>[1]; try { params[0] = entity.getClass().getDeclaredField(columnName).getType(); } catch (NoSuchFieldException e1) { e1.printStackTrace(); } catch (SecurityException e1) { e1.printStackTrace(); } String methodName = "set" + columnName.substring(0, 1).toUpperCase() + columnName.substring(1); Method method = null; try { method = entity.getClass().getMethod(methodName, params); } catch (NoSuchMethodException e) { continue; } try { method.invoke(entity, rs.getObject(i)); } catch (IllegalAccessException e) { e.printStackTrace(); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (InvocationTargetException e) { e.printStackTrace(); } } result.add(entity); } return result; } }