Java Code Examples for com.alibaba.druid.sql.ast.SQLStatement#accept()

The following examples show how to use com.alibaba.druid.sql.ast.SQLStatement#accept() . You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: SchemaUtil.java — from Mycat2 (GNU General Public License v3.0), 6 votes
/**
 * Walks {@code stmt} with {@code schemaStatVisitor} and splits the visitor's current table
 * name into schema and table parts.
 *
 * @param stmt              statement to inspect
 * @param schemaStatVisitor visitor whose {@code getCurrentTable()} is read after the walk
 * @return a {@link SchemaInfo} with {@code table} set (and {@code schema} too when the name
 *         was written as {@code schema.table}); {@code null} when no current table was found
 */
private static SchemaInfo parseTables(SQLStatement stmt, SchemaStatVisitor schemaStatVisitor) {
    stmt.accept(schemaStatVisitor);
    String qualifiedName = schemaStatVisitor.getCurrentTable();
    if (qualifiedName == null) {
        return null;
    }

    // Remove MySQL identifier quoting before splitting on the dot.
    String unquoted = qualifiedName.replace("`", "");
    SchemaInfo schemaInfo = new SchemaInfo();
    int dot = unquoted.indexOf('.');
    if (dot > 0) {
        schemaInfo.schema = unquoted.substring(0, dot);
        schemaInfo.table = unquoted.substring(dot + 1);
    } else {
        schemaInfo.table = unquoted;
    }
    return schemaInfo;
}
 
Example 2
Source File: MySQLDialect.java — from Zebra (Apache License 2.0), 6 votes
/**
 * Renders every SELECT statement found in {@code sql} through a
 * {@link MysqlCountOutputVisitor}, appending {@code ;} after each one.
 * Statements that are not SELECTs are skipped.
 *
 * @param sql one or more MySQL statements
 * @return the concatenated visitor output; empty when {@code sql} contains no SELECT
 */
@Override
public String getCountSql(String sql) {
	List<SQLStatement> statements = new MySqlStatementParser(sql).parseStatementList();

	// The visitor writes the rendered AST straight into countSql.
	StringBuilder countSql = new StringBuilder();
	MysqlCountOutputVisitor countVisitor = new MysqlCountOutputVisitor(countSql);

	for (SQLStatement statement : statements) {
		if (!(statement instanceof SQLSelectStatement)) {
			continue;
		}
		statement.accept(countVisitor);
		countSql.append(';');
	}
	return countSql.toString();
}
 
Example 3
Source File: DefaultDruidParser.java — from dble (GNU General Public License v2.0), 6 votes
/**
 * Walks {@code stmt} with {@code visitor}, rejects statements the visitor flagged as
 * unsupported, and stores route-calculate units built from the visitor's WHERE units in
 * {@code ctx}.
 *
 * @return {@code schemaConfig}, unchanged
 * @throws SQLNonTransientException when the visitor reports a not-supported message
 */
@Override
public SchemaConfig visitorParse(SchemaConfig schemaConfig, RouteResultset rrs, SQLStatement stmt, ServerSchemaStatVisitor visitor, ServerConnection sc, boolean isExplain)
        throws SQLException {
    stmt.accept(visitor);

    String unsupported = visitor.getNotSupportMsg();
    if (unsupported != null) {
        throw new SQLNonTransientException(unsupported);
    }

    String schemaName = (schemaConfig == null) ? null : schemaConfig.getName();
    Map<String, String> tableAliasMap = getTableAliasMap(schemaName, visitor.getAliasMap());
    ctx.setRouteCalculateUnits(
            ConditionUtil.buildRouteCalculateUnits(visitor.getAllWhereUnit(), tableAliasMap, schemaName));
    return schemaConfig;
}
 
Example 4
Source File: ReplaceTableNameVisitorTest.java — from baymax (Apache License 2.0), 6 votes
/**
 * Replaces table {@code logicName} with {@code targetName} in {@code sql}, prints a blank
 * line, the original SQL and the rewritten SQL, then resets the replace visitor.
 */
public void test(String sql, String logicName, String targetName){
    SQLStatement statement = new MySqlStatementParser(sql).parseStatement();

    ReplaceTableNameVisitor replaceVisitor = new ReplaceTableNameVisitor(logicName, targetName);
    StringBuilder rewritten = new StringBuilder();
    MySqlOutputVisitor printer = new MySqlOutputVisitor(rewritten);

    // First rename the table on the AST, then render the modified AST back to SQL.
    statement.accept(replaceVisitor);
    statement.accept(printer);

    System.out.println();
    System.out.println(sql);
    System.out.println(rewritten.toString());

    // After printing, restore the statement so the next table-name replacement starts clean.
    replaceVisitor.reset();
}
 
Example 5
Source File: QueryConditionAnalyzer.java — from Mycat2 (GNU General Public License v3.0), 5 votes
/**
 * Parses {@code sql} and collects the values that its conditions compare against a column,
 * provided the statement's current table is {@code tableName}.
 *
 * @param sql        SQL text to parse
 * @param tableName  table name, compared case-insensitively with the visitor's current table
 * @param colnumName column name, compared case-insensitively with each condition's column
 *                   (after {@code fixName})
 * @return the matching condition values (possibly empty); {@code null} when any argument is
 *         {@code null}
 */
public List<Object> parseConditionValues(String sql, String tableName, String colnumName)  {

	// Guard the parameters that are dereferenced below. Checking the unrelated columnName
	// field here would let a null colnumName reach equalsIgnoreCase and throw an NPE.
	if ( sql == null || tableName == null || colnumName == null ) {
		return null;
	}

	List<Object> values = new ArrayList<Object>();

	MySqlStatementParser parser = new MySqlStatementParser(sql);
	SQLStatement stmt = parser.parseStatement();

	MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
	stmt.accept(visitor);

	String currentTable = visitor.getCurrentTable();
	if ( !tableName.equalsIgnoreCase( currentTable ) ) {
		return values;
	}

	for (Condition condition : visitor.getConditions()) {
		String ccN = fixName( condition.getColumn().getName() );
		if ( colnumName.equalsIgnoreCase( ccN ) ) {
			values.addAll( condition.getValues() );
		}
	}
	return values;
}
 
Example 6
Source File: ShardLimitSqlWithConditionRewrite.java — from Zebra (Apache License 2.0), 5 votes
/**
 * Parses {@code limitSql} (bypassing the parse cache) and renders it through a
 * {@link ShardLimitSqlConditionVisitor} configured with the given rows, context and params.
 *
 * @return the SQL text written by the visitor
 */
public String rewrite(String limitSql, RowData startData, RowData endData, MergeContext context, List<Object> params) {
	SQLStatement parsed = SQLParser.parseWithoutCache(limitSql).getStmt();

	StringBuilder rewritten = new StringBuilder();
	parsed.accept(new ShardLimitSqlConditionVisitor(rewritten, startData, endData, context, params));

	return rewritten.toString();
}
 
Example 7
Source File: DruidParserFactory.java — from Mycat2 (GNU General Public License v3.0), 5 votes
/**
 * Collects the names of tables that {@code stmt} references by their own name.
 * <p>
 * After walking {@code stmt}, each alias-map entry is normalised: backquotes are removed
 * from key and value, and a leading {@code db.} qualifier is removed from the key. When the
 * normalised key equals the value, the key is added upper-cased.
 *
 * @return upper-cased table names in alias-map iteration order; empty when the visitor has
 *         no alias map
 */
private static List<String> parseTables(SQLStatement stmt, SchemaStatVisitor schemaStatVisitor)
{
    List<String> tables = new ArrayList<>();
    stmt.accept(schemaStatVisitor);

    Map<String, String> aliasMap = schemaStatVisitor.getAliasMap();
    if (aliasMap == null)
    {
        return tables;
    }

    for (Map.Entry<String, String> entry : aliasMap.entrySet())
    {
        String key = entry.getKey();
        if (key == null)
        {
            continue;
        }
        String value = entry.getValue();
        if (value != null)
        {
            value = value.replace("`", "");
        }
        // Strip every backquote unconditionally: quoted identifiers START with one
        // (index 0), so a "pos > 0" test would leave keys such as `t` still quoted.
        key = key.replace("`", "");
        // A table name prefixed with its database: drop the "db." part.
        int dot = key.indexOf('.');
        if (dot > 0)
        {
            key = key.substring(dot + 1);
        }
        if (key.equals(value))
        {
            tables.add(key.toUpperCase());
        }
    }
    return tables;
}
 
Example 8
Source File: SqlToCountSqlRewrite.java — from Zebra (Apache License 2.0), 5 votes
/**
 * Parses {@code sql}, walks it with a {@link RewriteSqlToCountSqlVisitor} built from
 * {@code countParams}, and returns the statement's text after the walk.
 */
public String rewrite(String sql, List<ParamContext> countParams) {
	SQLStatement statement = new MySqlStatementParser(sql).parseStatement();
	statement.accept(new RewriteSqlToCountSqlVisitor(countParams));
	// Presumably the visitor edits the AST in place; toString() renders the result.
	return statement.toString();
}
 
Example 9
Source File: SchemaUtil.java — from dble (GNU General Public License v2.0), 5 votes
/**
 * Checks a FROM clause and every subquery of {@code childSelectStmt} for no-sharding.
 * <p>
 * {@code tables} (when non-null) is dispatched on its concrete type to the matching
 * {@code isNoSharding} overload; any other table-source type yields {@code false}. Then
 * each subquery found by a {@link ServerSchemaStatVisitor} over {@code childSelectStmt} is
 * checked as well. The first failing check short-circuits to {@code false}.
 */
public static boolean isNoSharding(ServerConnection source, SQLTableSource tables, SQLStatement stmt, SQLStatement childSelectStmt, String contextSchema, Set<String> schemas, StringPtr dataNode)
        throws SQLException {
    if (tables != null) {
        boolean tablesOk;
        if (tables instanceof SQLExprTableSource) {
            tablesOk = isNoSharding(source, (SQLExprTableSource) tables, stmt, childSelectStmt, contextSchema, schemas, dataNode);
        } else if (tables instanceof SQLJoinTableSource) {
            tablesOk = isNoSharding(source, (SQLJoinTableSource) tables, stmt, childSelectStmt, contextSchema, schemas, dataNode);
        } else if (tables instanceof SQLSubqueryTableSource) {
            SQLSelectQuery subQuery = ((SQLSubqueryTableSource) tables).getSelect().getQuery();
            tablesOk = isNoSharding(source, subQuery, stmt, new SQLSelectStatement(new SQLSelect(subQuery)), contextSchema, schemas, dataNode);
        } else if (tables instanceof SQLUnionQueryTableSource) {
            tablesOk = isNoSharding(source, ((SQLUnionQueryTableSource) tables).getUnion(), stmt, contextSchema, schemas, dataNode);
        } else {
            // Unknown table-source kind: cannot prove it is unsharded.
            tablesOk = false;
        }
        if (!tablesOk) {
            return false;
        }
    }

    ServerSchemaStatVisitor subQueryCollector = new ServerSchemaStatVisitor();
    childSelectStmt.accept(subQueryCollector);
    for (SQLSelect subSelect : subQueryCollector.getSubQueryList()) {
        boolean subOk = isNoSharding(source, subSelect.getQuery(), stmt, new SQLSelectStatement(subSelect), contextSchema, schemas, dataNode);
        if (!subOk) {
            return false;
        }
    }
    return true;
}
 
Example 10
Source File: QueryConditionAnalyzer.java    From dble with GNU General Public License v2.0 5 votes vote down vote up
/**
 * parseConditionValues
 *
 * @param sql
 * @param table
 * @param column
 * @return
 */
public List<Object> parseConditionValues(String sql, String table, String column) {

    List<Object> values = null;

    if (sql != null && table != null && QueryConditionAnalyzer.this.columnName != null) {

        values = new ArrayList<>();

        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement stmt = parser.parseStatement();

        ServerSchemaStatVisitor visitor = new ServerSchemaStatVisitor();
        stmt.accept(visitor);

        String currentTable = visitor.getCurrentTable();
        if (table.equalsIgnoreCase(currentTable)) {

            List<Condition> conditions = visitor.getConditions();
            for (Condition condition : conditions) {

                String ccN = condition.getColumn().getName();
                ccN = fixName(ccN);

                if (column.equalsIgnoreCase(ccN)) {
                    List<Object> ccVL = condition.getValues();
                    values.addAll(ccVL);
                }
            }
        }
    }
    return values;
}
 
Example 11
Source File: DefaultDruidParser.java — from dble (GNU General Public License v2.0), 5 votes
/**
 * Renders {@code statement} back to MySQL SQL text using a {@link MySqlOutputVisitor}
 * with {@code setShardingSupport(false)}.
 *
 * @param statement parsed statement to print
 * @return the SQL text produced by the visitor
 */
String statementToString(SQLStatement statement) {
    // The buffer never leaves this method, so StringBuffer's synchronization is wasted;
    // MySqlOutputVisitor accepts any Appendable.
    StringBuilder buf = new StringBuilder();
    MySqlOutputVisitor visitor = new MySqlOutputVisitor(buf);
    visitor.setShardingSupport(false);
    statement.accept(visitor);
    return buf.toString();
}
 
Example 12
Source File: ServerSchemaStatVisitorTest.java — from dble (GNU General Public License v2.0), 5 votes
/**
 * Parses {@code sql} and returns the WHERE units collected by a
 * {@link ServerSchemaStatVisitor}.
 * <p>
 * Exceptions thrown while parsing or visiting are printed and turned into a {@code null}
 * result; exceptions from constructing the parser propagate.
 */
private List<WhereUnit> getAllWhereUnit(String sql) {
    SQLStatementParser parser = new MySqlStatementParser(sql);

    List<WhereUnit> whereUnits = null;
    try {
        // parseStatement may throw on invalid SQL.
        SQLStatement statement = parser.parseStatement();
        ServerSchemaStatVisitor collector = new ServerSchemaStatVisitor();
        statement.accept(collector);
        whereUnits = collector.getAllWhereUnit();
    } catch (Exception e) {
        // Best-effort test helper: report and fall through to the null result.
        e.printStackTrace();
    }
    return whereUnits;
}
 
Example 13
Source File: MycatPrivileges.java — from Mycat2 (GNU General Public License v3.0), 4 votes
/**
 * Decides whether {@code user} may run the DML statement {@code sql} in {@code schema}.
 * <p>
 * The check is skipped ({@code true}) when any of these holds:
 * <ul>
 *   <li>{@code schema} is null;</li>
 *   <li>the user has no config;</li>
 *   <li>privilege checking is absent or disabled for the user;</li>
 *   <li>the schema has no privilege entry;</li>
 *   <li>{@code sql} is "begin" (any case);</li>
 *   <li>the statement is not INSERT/REPLACE/UPDATE/SELECT/DELETE;</li>
 *   <li>the visitor finds no current table.</li>
 * </ul>
 * Otherwise the table's DML flag for the statement kind decides. A denial is logged to
 * {@code ALARM} with {@code Alarms.DML_ATTACK}. Parse errors from the statement parser
 * propagate.
 */
@Override
public boolean checkDmlPrivilege(String user, String schema, String sql) {
	if (schema == null) {
		return true;
	}

	MycatConfig conf = MycatServer.getInstance().getConfig();
	UserConfig userConfig = conf.getUsers().get(user);
	if (userConfig == null) {
		return true;
	}

	UserPrivilegesConfig userPrivilege = userConfig.getPrivilegesConfig();
	if (userPrivilege == null || !userPrivilege.isCheck()) {
		return true;
	}

	UserPrivilegesConfig.SchemaPrivilege schemaPrivilege = userPrivilege.getSchemaPrivilege(schema);
	if (schemaPrivilege == null) {
		return true;
	}

	// Workaround for https://github.com/alibaba/druid/issues/1309: a bare "begin" fails with
	// ParserException: syntax error, error in :'begin',expect END, actual EOF begin
	if ("begin".equalsIgnoreCase(sql)) {
		return true;
	}

	// TODO: worth optimizing; look for a better SQL parser here.
	SQLStatement stmt = new MycatStatementParser(sql).parseStatement();

	// Index into the table's getDml() flags: 0 insert/replace, 1 update, 2 select, 3 delete.
	int dmlIndex;
	if (stmt instanceof MySqlReplaceStatement || stmt instanceof SQLInsertStatement) {
		dmlIndex = 0;
	} else if (stmt instanceof SQLUpdateStatement) {
		dmlIndex = 1;
	} else if (stmt instanceof SQLSelectStatement) {
		dmlIndex = 2;
	} else if (stmt instanceof SQLDeleteStatement) {
		dmlIndex = 3;
	} else {
		return true;
	}

	SchemaStatVisitor schemaStatVisitor = new MycatSchemaStatVisitor();
	stmt.accept(schemaStatVisitor);
	String currentTable = schemaStatVisitor.getCurrentTable();
	if (currentTable == null) {
		return true;
	}

	String unquoted = currentTable.replace("`", "");
	int dot = unquoted.indexOf('.');
	String tableName = (dot > 0) ? unquoted.substring(dot + 1) : unquoted;

	// Per the original note: no null check needed; a table without its own settings
	// automatically inherits the parent (schema) privileges.
	UserPrivilegesConfig.TablePrivilege tablePrivilege = schemaPrivilege.getTablePrivilege(tableName);
	boolean isPassed = tablePrivilege.getDml()[dmlIndex] > 0;

	if (!isPassed) {
		ALARM.error(new StringBuilder().append(Alarms.DML_ATTACK ).append("[sql=").append( sql )
                .append(",user=").append(user).append(']').toString());
	}
	return isPassed;
}
 
Example 14
Source File: SqlParseUtils.java — from jeesuite-libs (Apache License 2.0), 4 votes
/**
 * Demo: parses a multi-table DELETE and prints the tables, columns and conditions
 * collected by a {@link MySqlSchemaStatVisitor} from its first statement.
 */
public static void main(String[] args) {
	String sql = "DELETE a1, a2 FROM t1 AS a1 INNER JOIN t2 AS a2 WHERE a1.id=a2.id;";

	List<SQLStatement> statements = new MySqlStatementParser(sql).parseStatementList();
	SQLStatement firstStatement = statements.get(0);

	MySqlSchemaStatVisitor schemaVisitor = new MySqlSchemaStatVisitor();
	firstStatement.accept(schemaVisitor);

	System.out.println(schemaVisitor.getTables());
	System.out.println(schemaVisitor.getColumns());
	System.out.println(schemaVisitor.getConditions());
}
 
Example 15
Source File: SqlParseUtils.java — from azeroth (Apache License 2.0), 4 votes
/**
 * Demo entry point: prints what a {@link MySqlSchemaStatVisitor} extracts (tables, then
 * columns, then conditions) from the first statement of a sample multi-table DELETE.
 */
public static void main(String[] args) {
    String sql = "DELETE a1, a2 FROM t1 AS a1 INNER JOIN t2 AS a2 WHERE a1.id=a2.id;";

    MySqlStatementParser parser = new MySqlStatementParser(sql);
    SQLStatement deleteStatement = parser.parseStatementList().get(0);

    MySqlSchemaStatVisitor statVisitor = new MySqlSchemaStatVisitor();
    deleteStatement.accept(statVisitor);

    System.out.println(statVisitor.getTables());
    System.out.println(statVisitor.getColumns());
    System.out.println(statVisitor.getConditions());
}
 
Example 16
Source File: DQLRouteTest.java — from Mycat2 (GNU General Public License v3.0), 4 votes
/**
 * Test helper: walks {@code stmt} with {@code visitor}, registers the statement's tables and
 * aliases in {@code ctx}/{@code tableAliasMap}, then calls the private
 * {@code DefaultDruidParser.buildRouteCalculateUnits} reflectively.
 *
 * @return a list holding the first route-calculate unit produced, or an empty list when the
 *         reflective call returned no units (or not a list)
 * @throws Exception reflection failures, or anything thrown by the invoked method
 */
@SuppressWarnings("unchecked")
private List<RouteCalculateUnit> visitorParse(RouteResultset rrs, SQLStatement stmt, MycatSchemaStatVisitor visitor) throws Exception {

	stmt.accept(visitor);

	List<List<Condition>> mergedConditionList = new ArrayList<List<Condition>>();
	if (visitor.hasOrCondition()) {
		// Contains OR: split the conditions on it. TODO
		mergedConditionList = visitor.splitConditions();
	} else {
		// No OR: a single condition group.
		mergedConditionList.add(visitor.getConditions());
	}

	if (visitor.getAliasMap() != null) {
		for (Map.Entry<String, String> entry : visitor.getAliasMap().entrySet()) {
			String key = entry.getKey();
			if (key == null) {
				// Nothing to register; equals()/toUpperCase() below would throw NPE.
				continue;
			}
			String value = entry.getValue();
			key = key.replace("`", "");
			if (value != null) {
				value = value.replace("`", "");
			}
			// A table name prefixed with its database: drop the "db." part.
			int pos = key.indexOf('.');
			if (pos > 0) {
				key = key.substring(pos + 1);
			}

			if (key.equals(value)) {
				ctx.addTable(key.toUpperCase());
			}
			tableAliasMap.put(key.toUpperCase(), value);
		}
		visitor.getAliasMap().putAll(tableAliasMap);
		ctx.setTableAliasMap(tableAliasMap);
	}

	// Use reflection to unit-test DefaultDruidParser's private method buildRouteCalculateUnits.
	Class<?> clazz = Class.forName("io.mycat.route.parser.druid.impl.DefaultDruidParser");
	Method buildRouteCalculateUnits = clazz.getDeclaredMethod("buildRouteCalculateUnits",
			SchemaStatVisitor.class, List.class);
	// Class.newInstance() is deprecated since Java 9; go through the no-arg constructor.
	Object newInstance = clazz.getDeclaredConstructor().newInstance();
	buildRouteCalculateUnits.setAccessible(true);
	Object returnValue = buildRouteCalculateUnits.invoke(newInstance, visitor, mergedConditionList);

	List<RouteCalculateUnit> retList = new ArrayList<RouteCalculateUnit>();
	if (returnValue instanceof List<?>) {
		List<?> units = (List<?>) returnValue;
		// Guard the empty case instead of letting get(0) throw IndexOutOfBoundsException.
		if (!units.isEmpty()) {
			retList.add((RouteCalculateUnit) units.get(0));
		}
	}
	return retList;
}
 
Example 17
Source File: DefaultDruidParser.java — from Mycat2 (GNU General Public License v3.0), 4 votes
/**
 * Parses {@code stmt} with {@code visitor} and prepares {@code ctx} for routing.
 * <p>
 * Subclasses may override this. When visitor-based parsing cannot obtain the table names,
 * columns, etc. for their statement type, they can override it with an empty method and
 * parse the statement in statementParse instead.
 * <p>
 * Steps, in order:
 * <ol>
 *   <li>walk the statement and store the visitor in {@code ctx};</li>
 *   <li>flag {@code rrs} for SELECT ... FOR UPDATE;</li>
 *   <li>group the conditions (split on OR when present);</li>
 *   <li>refresh the SQL if the visitor changed it;</li>
 *   <li>normalise the alias map into {@code tableAliasMap};</li>
 *   <li>build the route-calculate units.</li>
 * </ol>
 *
 * @param rrs     route result set updated for FOR UPDATE and rewritten SQL
 * @param stmt    statement to analyse
 * @param visitor visitor that collects tables, aliases and conditions
 */
@Override
public void visitorParse(RouteResultset rrs, SQLStatement stmt,MycatSchemaStatVisitor visitor) throws SQLNonTransientException{
	stmt.accept(visitor);
	ctx.setVisitor(visitor);

	if (stmt instanceof SQLSelectStatement) {
		SQLSelectQuery query = ((SQLSelectStatement) stmt).getSelect().getQuery();
		boolean forUpdate = query instanceof MySqlSelectQueryBlock
				&& ((MySqlSelectQueryBlock) query).isForUpdate();
		if (forUpdate) {
			rrs.setSelectForUpdate(true);
		}
	}

	List<List<Condition>> mergedConditionList;
	if (visitor.hasOrCondition()) {
		// Contains OR: split the conditions on it. TODO
		mergedConditionList = visitor.splitConditions();
	} else {
		// No OR: a single condition group.
		mergedConditionList = new ArrayList<List<Condition>>();
		mergedConditionList.add(visitor.getConditions());
	}

	// A subquery was rewritten during parsing, so ctx must pick up the new SQL text.
	if (visitor.isHasChange()) {
		ctx.setSql(stmt.toString());
		rrs.setStatement(ctx.getSql());
	}

	Map<String, String> aliasMap = visitor.getAliasMap();
	if (aliasMap != null) {
		for (Map.Entry<String, String> entry : aliasMap.entrySet()) {
			String rawKey = entry.getKey();
			if (rawKey == null) {
				continue;
			}
			String normalizedKey = rawKey.replace("`", "");
			// A table name prefixed with its database: drop the "db." part.
			int dot = normalizedKey.indexOf('.');
			if (dot > 0) {
				normalizedKey = normalizedKey.substring(dot + 1);
			}
			String rawValue = entry.getValue();
			String normalizedValue = (rawValue == null) ? null : rawValue.replace("`", "");
			tableAliasMap.put(normalizedKey.toUpperCase(), normalizedValue);
		}
		ctx.addTables(visitor.getTables());

		aliasMap.putAll(tableAliasMap);
		ctx.setTableAliasMap(tableAliasMap);
	}
	ctx.setRouteCalculateUnits(this.buildRouteCalculateUnits(visitor, mergedConditionList));
}
 
Example 18
Source File: ShardLimitSqlSplitRewrite.java — from Zebra (Apache License 2.0), 3 votes
/**
 * Renders the parsed statement held by {@code limitSql} through a
 * {@link ShardLimitSqlSplitVisior} configured with {@code splitNum} and {@code params}.
 *
 * @return the SQL text written by the visitor
 */
public String rewrite(SQLParsedResult limitSql, int splitNum, List<Object> params) {
	StringBuilder rewritten = new StringBuilder();
	// getStmt() is evaluated before the visitor is constructed, as in a two-step version.
	limitSql.getStmt().accept(new ShardLimitSqlSplitVisior(rewritten, splitNum, params));
	return rewritten.toString();
}