package io.mycat.parser.druid; import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.SQLStatement; import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement; import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; import io.mycat.config.model.SchemaConfig; import io.mycat.config.model.TableConfig; import io.mycat.route.RouteResultset; import io.mycat.route.parser.druid.impl.DruidUpdateParser; import org.junit.Assert; import org.junit.Test; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.sql.SQLNonTransientException; import java.util.ArrayList; import java.util.List; import java.util.Map; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; /** * @author Hash Zhang * @version 1.0.0 * @date 2016/7/7 */ public class DruidUpdateParserTest { /** * 测试单表更新分片字段 * @throws NoSuchMethodException */ @Test public void testUpdateShardColumn() throws NoSuchMethodException{ throwExceptionParse("update hotnews set id = 1 where name = 234;", true); throwExceptionParse("update hotnews set id = 1 where id = 3;", true); throwExceptionParse("update hotnews set id = 1, name = '123' where id = 1 and name = '234'", false); throwExceptionParse("update hotnews set id = 1, name = '123' where id = 1 or name = '234'", true); throwExceptionParse("update hotnews set id = 'A', name = '123' where id = 'A' and name = '234'", false); throwExceptionParse("update hotnews set id = 'A', name = '123' where id = 'A' or name = '234'", true); throwExceptionParse("update hotnews set id = 1.5, name = '123' where id = 1.5 and name = '234'", false); throwExceptionParse("update hotnews set id = 1.5, name = '123' where id = 1.5 or name = '234'", true); throwExceptionParse("update hotnews set id = 1, name = '123' where name = '234' and (id = 1 or age > 3)", true); throwExceptionParse("update hotnews set id = 1, name = '123' where id = 1 and (name = '234' or age > 3)", false); // 子查询,特殊的运算符between等情况 throwExceptionParse("update hotnews set id = 1, name = '123' where id = 1 and name in (select name from test)", false); throwExceptionParse("update hotnews set id = 1, name = '123' where name = '123' and id in (select id from test)", true); throwExceptionParse("update hotnews set id = 1, name = '123' where id between 1 and 3", true); throwExceptionParse("update hotnews set id = 1, name = '123' where id between 1 and 3 and name = '234'", true); throwExceptionParse("update hotnews set id = 1, name = '123' where id between 1 and 3 or name = '234'", true); throwExceptionParse("update hotnews set id = 1, name = '123' where id = 1 and name between '124' and '234'", false); } /** * 测试单表别名更新分片字段 * @throws NoSuchMethodException */ @Test public void testAliasUpdateShardColumn() throws NoSuchMethodException{ throwExceptionParse("update hotnews h set h.id = 1 where h.name = 234;", true); throwExceptionParse("update hotnews h set h.id = 1 where h.id = 3;", true); throwExceptionParse("update hotnews h set h.id = 1, h.name = '123' where h.id = 1 and h.name = '234'", false); throwExceptionParse("update hotnews h set h.id = 1, h.name = '123' where h.id = 1 or h.name = '234'", true); throwExceptionParse("update hotnews h set h.id = 'A', h.name = '123' where h.id = 'A' and h.name = '234'", false); throwExceptionParse("update hotnews h set h.id = 'A', h.name = '123' where h.id = 'A' or h.name = '234'", true); throwExceptionParse("update hotnews h set h.id = 1.5, h.name = '123' where h.id = 1.5 and h.name = '234'", false); throwExceptionParse("update hotnews h set h.id = 1.5, h.name = '123' where h.id = 1.5 or h.name = '234'", true); throwExceptionParse("update hotnews h set id = 1, h.name = '123' where h.id = 1 and h.name = '234'", false); throwExceptionParse("update hotnews h set h.id = 1, h.name = '123' where id = 1 or h.name = '234'", true); throwExceptionParse("update hotnews h set h.id = 1, h.name = '123' where h.name = '234' and (h.id = 1 or h.age > 3)", true); throwExceptionParse("update hotnews h set h.id = 1, h.name = '123' where h.id = 1 and (h.name = '234' or h.age > 3)", false); } public void throwExceptionParse(String sql, boolean throwException) throws NoSuchMethodException { MySqlStatementParser parser = new MySqlStatementParser(sql); List<SQLStatement> statementList = parser.parseStatementList(); SQLStatement sqlStatement = statementList.get(0); MySqlUpdateStatement update = (MySqlUpdateStatement) sqlStatement; SchemaConfig schemaConfig = mock(SchemaConfig.class); Map<String, TableConfig> tables = mock(Map.class); TableConfig tableConfig = mock(TableConfig.class); String tableName = "hotnews"; when((schemaConfig).getTables()).thenReturn(tables); when(tables.get(tableName)).thenReturn(tableConfig); when(tableConfig.getParentTC()).thenReturn(null); RouteResultset routeResultset = new RouteResultset(sql, 11); Class c = DruidUpdateParser.class; Method method = c.getDeclaredMethod("confirmShardColumnNotUpdated", new Class[]{SQLUpdateStatement.class, SchemaConfig.class, String.class, String.class, String.class, RouteResultset.class}); method.setAccessible(true); try { method.invoke(c.newInstance(), update, schemaConfig, tableName, "ID", "", routeResultset); if (throwException) { System.out.println("未抛异常,解析通过则不对!"); Assert.assertTrue(false); } else { System.out.println("未抛异常,解析通过,此情况分片字段可能在update语句中但是实际不会被更新"); Assert.assertTrue(true); } } catch (Exception e) { if (throwException) { System.out.println(e.getCause().getClass()); Assert.assertTrue(e.getCause() instanceof SQLNonTransientException); System.out.println("抛异常原因为SQLNonTransientException则正确"); } else { System.out.println("抛异常,需要检查"); Assert.assertTrue(false); } } } /* * 添加一个static方法用于打印一个SQL的where子句,比如这样的一条SQL: * update mytab t set t.ptn_col = 'A', col1 = 3 where ptn_col = 'A' and (col1 = 4 or col2 > 5); * where子句的语法树如下 * AND * / \ * = OR * / \ / \ * ptn_col 'A' = > * / \ / \ * col1 4 col2 5 * 其输出如下,(按层输出,并且每层最后输出下一层的节点数目) * BooleanAnd Num of nodes in next level: 2 * Equality BooleanOr Num of nodes in next level: 4 * ptn_col 'A' Equality Equality Num of nodes in next level: 4 * col1 4 col2 5 Num of nodes in next level: 0 * * 因为大部分的update的where子句都比较简单,按层次打印应该足够清晰,未来可以完全按照逻辑打印类似上面的整棵树结构 */ public static void printWhereClauseAST(SQLExpr sqlExpr) { // where子句的AST sqlExpr可以通过 MySqlUpdateStatement.getWhere(); 获得 if (sqlExpr == null) return; ArrayList<SQLExpr> exprNode = new ArrayList<>(); int i = 0, curLevel = 1, nextLevel = 0; SQLExpr iterExpr; exprNode.add(sqlExpr); while (true) { iterExpr = exprNode.get(i++); if (iterExpr == null) break; if (iterExpr instanceof SQLBinaryOpExpr) { System.out.print(((SQLBinaryOpExpr) iterExpr).getOperator()); } else { System.out.print(iterExpr.toString()); } System.out.print("\t"); curLevel--; if (iterExpr instanceof SQLBinaryOpExpr) { if (((SQLBinaryOpExpr) iterExpr).getLeft() != null) { exprNode.add(((SQLBinaryOpExpr) iterExpr).getLeft()); nextLevel++; } if (((SQLBinaryOpExpr) iterExpr).getRight() != null) { exprNode.add(((SQLBinaryOpExpr) iterExpr).getRight()); nextLevel++; } } if (curLevel == 0) { System.out.println("\t\tNum of nodes in next level: " + nextLevel); curLevel = nextLevel; nextLevel = 0; } if (exprNode.size() == i) break; } } }