Java Code Examples for com.alibaba.druid.sql.visitor.SchemaStatVisitor

The following examples, extracted from open-source projects, show how to use com.alibaba.druid.sql.visitor.SchemaStatVisitor. Where available, each example is preceded by a line naming its source project, source file, license, and vote count.
Example 1
Source project: Mycat2 | Source file: SchemaUtil.java | License: GNU General Public License v3.0 | Votes: 6
/**
 * Runs {@code schemaStatVisitor} over {@code stmt} and splits the visitor's current table name
 * into its schema and table parts.
 *
 * <p>Backquotes are removed first. A name containing a dot after its first character is treated
 * as {@code schema.table} and fills both fields; any other name fills only the table field.
 *
 * @param stmt              the parsed statement to visit
 * @param schemaStatVisitor visitor that collects table information; it is mutated by the visit
 * @return the schema/table pair, or {@code null} when the visitor reports no current table
 */
private static SchemaInfo parseTables(SQLStatement stmt, SchemaStatVisitor schemaStatVisitor) {
    stmt.accept(schemaStatVisitor);
    String currentTable = schemaStatVisitor.getCurrentTable();
    if (currentTable == null) {
        return null;
    }

    // Literal replacement: a backquote has no special meaning, so no regex is needed.
    String unquoted = currentTable.replace("`", "");
    SchemaInfo info = new SchemaInfo();
    int dot = unquoted.indexOf('.');
    if (dot <= 0) {
        info.table = unquoted;
    } else {
        info.schema = unquoted.substring(0, dot);
        info.table = unquoted.substring(dot + 1);
    }
    return info;
}
 
Example 2
Source project: DataLink | Source file: SQLStatementHolder.java | License: Apache License 2.0 | Votes: 5
/**
 * Creates a holder for a parsed SQL statement together with the visitor associated with it.
 *
 * <p>The three arguments are stored first. Then {@code sqlString} is computed via
 * {@code buildSqlString()}, and {@code sqlCheckItems} is initialised with
 * {@code Lists.newArrayList()}.
 *
 * @param sqlStatement      the parsed statement
 * @param schemaStatVisitor the schema visitor associated with the statement
 * @param mediaSourceType   the media source type associated with the statement
 */
public SQLStatementHolder(SQLStatement sqlStatement, SchemaStatVisitor schemaStatVisitor, MediaSourceType mediaSourceType) {
    this.sqlStatement = sqlStatement;
    this.schemaStatVisitor = schemaStatVisitor;
    this.mediaSourceType = mediaSourceType;
    // NOTE(review): keep this after the field assignments above. buildSqlString() is presumably
    // derived from those fields (its body is not visible here), so confirm before reordering.
    this.sqlString = buildSqlString();
    this.sqlCheckItems = Lists.newArrayList();
}
 
Example 3
Source project: Mycat2 | Source file: DruidParserFactory.java | License: GNU General Public License v3.0 | Votes: 5
/**
 * Chooses the {@code DruidParser} implementation for a statement based on its runtime type.
 *
 * <p>For SELECT statements on schemas that need multi-database support, a database-specific
 * parser is tried first via {@code getDruidParserForMultiDB}. If that yields nothing, or
 * multi-DB support is off, the generic {@code DruidSelectParser} is used. Statement types without
 * a dedicated parser fall back to {@code DefaultDruidParser}.
 *
 * @param schema    the schema configuration the statement runs against
 * @param statement the parsed statement
 * @param visitor   passed through to the multi-DB lookup for SELECT statements
 * @return a parser instance; never {@code null}
 */
public static DruidParser create(SchemaConfig schema, SQLStatement statement, SchemaStatVisitor visitor) {
    if (statement instanceof SQLSelectStatement) {
        DruidParser dbSpecific = schema.isNeedSupportMultiDBType()
                ? getDruidParserForMultiDB(schema, statement, visitor)
                : null;
        return dbSpecific != null ? dbSpecific : new DruidSelectParser();
    }
    if (statement instanceof MySqlInsertStatement) {
        return new DruidInsertParser();
    }
    if (statement instanceof MySqlDeleteStatement) {
        return new DruidDeleteParser();
    }
    if (statement instanceof MySqlCreateTableStatement) {
        return new DruidCreateTableParser();
    }
    if (statement instanceof MySqlUpdateStatement) {
        return new DruidUpdateParser();
    }
    if (statement instanceof SQLAlterTableStatement) {
        return new DruidAlterTableParser();
    }
    if (statement instanceof MySqlLockTableStatement) {
        return new DruidLockTableParser();
    }
    return new DefaultDruidParser();
}
 
Example 4
Source project: Mycat2 | Source file: DruidParserFactory.java | License: GNU General Public License v3.0 | Votes: 5
/**
 * Visits {@code stmt} and returns the upper-cased names of the real tables it references.
 *
 * <p>An alias-map entry counts as a real table when its key and value name the same table after
 * normalisation. Normalisation removes backquotes from both and drops any {@code database.}
 * prefix from the key. Entries that map an alias to a different table, and entries with a
 * {@code null} key, are skipped.
 *
 * @param stmt              the parsed statement to visit
 * @param schemaStatVisitor visitor used to collect the alias map; it is mutated by the visit
 * @return upper-cased table names; possibly empty, never {@code null}
 */
private static List<String> parseTables(SQLStatement stmt, SchemaStatVisitor schemaStatVisitor)
{
    List<String> tables = new ArrayList<>();
    stmt.accept(schemaStatVisitor);

    Map<String, String> aliasMap = schemaStatVisitor.getAliasMap();
    if (aliasMap == null)
    {
        return tables;
    }

    for (Map.Entry<String, String> entry : aliasMap.entrySet())
    {
        String key = entry.getKey();
        if (key == null)
        {
            continue;
        }
        String value = entry.getValue();
        if (value != null)
        {
            value = value.replace("`", "");
        }
        // Strip backquotes unconditionally. Keys that START with one (e.g. "`user`" or
        // "`db`.`t`") must be normalised too, otherwise they never match the stripped value.
        key = key.replace("`", "");
        // Drop a leading "database." qualifier from the table name.
        int dot = key.indexOf('.');
        if (dot > 0)
        {
            key = key.substring(dot + 1);
        }
        if (key.equals(value))
        {
            tables.add(key.toUpperCase());
        }
    }
    return tables;
}
 
Example 5
Source project: DataLink | Source file: SQLStatementHolder.java | License: Apache License 2.0 | Votes: 4
/**
 * Returns the schema visitor held by this object.
 *
 * @return the stored visitor, exactly as last assigned
 */
public SchemaStatVisitor getSchemaStatVisitor() {
    return this.schemaStatVisitor;
}
 
Example 6
Source project: DataLink | Source file: SQLStatementHolder.java | License: Apache License 2.0 | Votes: 4
/**
 * Replaces the schema visitor held by this object.
 *
 * @param visitor the visitor to store; it is not validated
 */
public void setSchemaStatVisitor(SchemaStatVisitor visitor) {
    schemaStatVisitor = visitor;
}
 
Example 7
Source project: Mycat2 | Source file: MycatPrivileges.java | License: GNU General Public License v3.0 | Votes: 4
/**
 * Checks whether {@code user} may run the DML statement {@code sql} against {@code schema}.
 *
 * <p>The check is permissive by default. It returns {@code true} whenever there is nothing to
 * enforce:
 * <ul>
 *   <li>no schema or an unknown user;</li>
 *   <li>privilege checking disabled, or no privilege entry for the schema;</li>
 *   <li>the literal statement {@code begin};</li>
 *   <li>a statement that is not INSERT/REPLACE, UPDATE, SELECT or DELETE;</li>
 *   <li>a statement for which the visitor reports no current table.</li>
 * </ul>
 * Otherwise the table privilege's DML flag for the statement kind decides. A denied request is
 * reported through {@code ALARM} before {@code false} is returned.
 *
 * <p>NOTE(review): only the visitor's single "current table" is checked. Statements that touch
 * several tables (joins, sub-queries) therefore appear to be authorised against one table
 * only. Confirm this is acceptable for the privilege model.
 *
 * @param user   the connecting user name
 * @param schema the schema the statement runs in; {@code null} skips the check
 * @param sql    the raw SQL text
 * @return {@code true} if the statement is allowed or not subject to checking
 */
@Override
public boolean checkDmlPrivilege(String user, String schema, String sql) {
    if (schema == null) {
        return true;
    }

    MycatConfig conf = MycatServer.getInstance().getConfig();
    UserConfig userConfig = conf.getUsers().get(user);
    if (userConfig == null) {
        return true;
    }

    UserPrivilegesConfig userPrivilege = userConfig.getPrivilegesConfig();
    if (userPrivilege == null || !userPrivilege.isCheck()) {
        return true;
    }

    UserPrivilegesConfig.SchemaPrivilege schemaPrivilege = userPrivilege.getSchemaPrivilege(schema);
    if (schemaPrivilege == null) {
        return true;
    }

    // TODO: consider a better SQL parser here.
    // Workaround for https://github.com/alibaba/druid/issues/1309: a bare "begin" makes the parser
    // throw "syntax error, error in :'begin',expect END, actual EOF begin".
    if ("begin".equalsIgnoreCase(sql)) {
        return true;
    }

    SQLStatementParser parser = new MycatStatementParser(sql);
    SQLStatement stmt = parser.parseStatement();

    // Index into TablePrivilege.getDml(): 0 = insert/replace, 1 = update, 2 = select, 3 = delete.
    final int dmlIndex;
    if (stmt instanceof MySqlReplaceStatement || stmt instanceof SQLInsertStatement) {
        dmlIndex = 0;
    } else if (stmt instanceof SQLUpdateStatement) {
        dmlIndex = 1;
    } else if (stmt instanceof SQLSelectStatement) {
        dmlIndex = 2;
    } else if (stmt instanceof SQLDeleteStatement) {
        dmlIndex = 3;
    } else {
        return true; // not a statement kind covered by DML privileges
    }

    SchemaStatVisitor schemaStatVisitor = new MycatSchemaStatVisitor();
    stmt.accept(schemaStatVisitor);
    String currentTable = schemaStatVisitor.getCurrentTable();
    if (currentTable == null) {
        return true;
    }

    // Remove backquotes, then any "schema." prefix, to obtain the bare table name.
    String unquoted = currentTable.replace("`", "");
    int dot = unquoted.indexOf('.');
    String tableName = dot > 0 ? unquoted.substring(dot + 1) : unquoted;

    // No null check: per the original author, a table without its own settings inherits the
    // parent (schema-level) privileges.
    UserPrivilegesConfig.TablePrivilege tablePrivilege = schemaPrivilege.getTablePrivilege(tableName);
    if (tablePrivilege.getDml()[dmlIndex] > 0) {
        return true;
    }

    ALARM.error(Alarms.DML_ATTACK + "[sql=" + sql + ",user=" + user + "]");
    return false;
}
 
Example 8
/**
 * Visits a binary operator expression while collecting WHERE-clause information.
 *
 * <p>The method does four things:
 * <ul>
 *   <li>Gives both operands {@code x} as their parent.</li>
 *   <li>If {@code x}'s parent carries a {@code SchemaStatVisitor.ATTR_TABLE} attribute, makes that
 *       table the current table before any condition is recorded.</li>
 *   <li>For comparison and IS operators, records a condition for each side and a relationship
 *       between the two sides.</li>
 *   <li>For a {@code BooleanOr} that {@code RouterUtil.isConditionAlwaysTrue} does not consider
 *       always true, sets {@code hasOrCondition} and appends a {@code WhereUnit} to
 *       {@code whereUnits}. An always-true OR is dropped.</li>
 * </ul>
 *
 * @param x the binary expression being visited
 * @return {@code false} for {@code BooleanOr} (its children are not visited by this method),
 *         {@code true} otherwise
 */
@Override
public boolean visit(SQLBinaryOpExpr x) {
       x.getLeft().setParent(x);
       x.getRight().setParent(x);
       
       /*
        * Bug fix: when the select list contains several sub-queries and the main table has no
        * alias, the main table's filter conditions were wrongly attached to a sub-query.
        *  e.g. select (select id from subtest2 where id = 1), (select id from subtest3 where id = 2) from subtest1 where id =4;
        *  Here subtest1's filter "id = 4" was added to subtest3. With table aliases this works
        *  correctly; without them the problem occurs.
        *  So first set which table is being operated on, then evaluate the expression.
        */
       String currenttable = x.getParent()==null?null: (String) x.getParent().getAttribute(SchemaStatVisitor.ATTR_TABLE);
       if(currenttable!=null){
       	this.setCurrentTable(currenttable);
       }
       
       switch (x.getOperator()) {
           case Equality:
           case LessThanOrEqualOrGreaterThan:
           case Is:
           case IsNot:
           case GreaterThan:
           case GreaterThanOrEqual:
           case LessThan:
           case LessThanOrEqual:
           case NotLessThan:
           case LessThanOrGreater:
		case NotEqual:
		case NotGreaterThan:
               // Record a condition on each side and the relationship between the two sides.
               handleCondition(x.getLeft(), x.getOperator().name, x.getRight());
               handleCondition(x.getRight(), x.getOperator().name, x.getLeft());
               handleRelationship(x.getLeft(), x.getOperator().name, x.getRight());
               break;
           case BooleanOr:
           	// An always-true OR condition is discarded from the WHERE analysis.
           	if(!RouterUtil.isConditionAlwaysTrue(x)) {
           		hasOrCondition = true;
           		
           		WhereUnit whereUnit = null;
           		if(conditions.size() > 0) {
           			// Conditions collected so far become this unit's outer conditions; the OR
           			// expression itself is placed in a nested sub-unit.
           			whereUnit = new WhereUnit();
           			whereUnit.setFinishedParse(true);
           			whereUnit.addOutConditions(getConditions());
           			WhereUnit innerWhereUnit = new WhereUnit(x);
           			whereUnit.addSubWhereUnit(innerWhereUnit);
           		} else {
           			whereUnit = new WhereUnit(x);
           			whereUnit.addOutConditions(getConditions());
           		}
           		whereUnits.add(whereUnit);
           	}
           	// Do not descend into the OR's operands here.
           	return false;
           // LIKE / NOT LIKE and all other operators: nothing is recorded here.
           case Like:
           case NotLike:
           default:
               break;
       }
       return true;
   }
 
Example 9
Source project: Mycat2 | Source file: DefaultDruidParser.java | License: GNU General Public License v3.0 | Votes: 4
/**
 * Builds one {@code RouteCalculateUnit} per condition group, collecting the sharding expressions
 * that route calculation can use.
 *
 * <p>Within each group, a condition is considered only if its value list is non-empty and passes
 * {@code checkConditionValues}. For each such condition:
 * <ul>
 *   <li>A table alias is resolved through the visitor's alias map.</li>
 *   <li>Any {@code schema.} prefix is removed and the name is upper-cased.</li>
 *   <li>Conditions on names missing from the alias map (e.g. sub-query aliases) are skipped,
 *       because the table could not be found later.</li>
 * </ul>
 * Only the {@code between}, {@code =} and {@code in} operators contribute; {@code between} and
 * {@code in} are matched case-insensitively. All other operators are ignored.
 *
 * @param visitor       the visitor that parsed the statement; source of the alias map
 * @param conditionList condition groups, each producing one unit
 * @return one unit per group, in the same order as {@code conditionList}
 */
private List<RouteCalculateUnit> buildRouteCalculateUnits(SchemaStatVisitor visitor, List<List<Condition>> conditionList) {
	List<RouteCalculateUnit> retList = new ArrayList<RouteCalculateUnit>();

	// Walk every condition group looking for sharding columns.
	for (List<Condition> conditions : conditionList) {
		RouteCalculateUnit routeCalculateUnit = new RouteCalculateUnit();
		for (Condition condition : conditions) {
			List<Object> values = condition.getValues();
			if (values.size() == 0 || !checkConditionValues(values)) {
				continue;
			}
			String columnName = StringUtil.removeBackquote(condition.getColumn().getName().toUpperCase());
			String tableName = StringUtil.removeBackquote(condition.getColumn().getTable().toUpperCase());

			// Resolve an alias to the real table name.
			if (visitor.getAliasMap() != null && visitor.getAliasMap().get(tableName) != null
					&& !visitor.getAliasMap().get(tableName).equals(tableName)) {
				tableName = visitor.getAliasMap().get(tableName);
			}
			// Handle the schema.table form.
			int dot = tableName.indexOf(".");
			if (dot != -1) {
				tableName = tableName.substring(dot + 1);
			}
			// Make sure the table name is upper case (the alias-map value may not be).
			tableName = tableName.toUpperCase();
			// Conditions on sub-query aliases are ignored and do not take part in route
			// calculation; otherwise the table cannot be found later.
			if (visitor.getAliasMap() != null && visitor.getAliasMap().get(tableName) == null) {
				continue;
			}

			String operator = condition.getOperator();
			// Only between, = and in are used for routing; match both keywords
			// case-insensitively so they are treated consistently.
			if (operator.equalsIgnoreCase("between")) {
				RangeValue rv = new RangeValue(values.get(0), values.get(1), RangeValue.EE);
				routeCalculateUnit.addShardingExpr(tableName, columnName, rv);
			} else if (operator.equals("=") || operator.equalsIgnoreCase("in")) {
				routeCalculateUnit.addShardingExpr(tableName, columnName, values.toArray());
			}
		}
		retList.add(routeCalculateUnit);
	}
	return retList;
}
 
Example 10
Source project: Mycat2 | Source file: DruidParserFactory.java | License: GNU General Public License v3.0 | Votes: 4
/**
 * Returns a database-specific SELECT parser based on the database types of the tables that
 * {@code statement} references, or {@code null} if none of them needs one.
 *
 * <p>Tables are examined in the order returned by {@code parseTables}. The first table whose
 * database types include {@code oracle}, {@code db2}, {@code sqlserver} or {@code postgresql}
 * (checked in that order) decides the parser. A table without a {@code TableConfig} uses the
 * schema's default data-node database type.
 *
 * @param schema    the schema configuration
 * @param statement the parsed SELECT statement
 * @param visitor   the caller's visitor. It is deliberately NOT used, so that it is not polluted
 *                  for later SQL parsing (by SvenAugustus); a fresh one is created instead.
 * @return the matching parser, or {@code null}
 */
private static DruidParser getDruidParserForMultiDB(SchemaConfig schema, SQLStatement statement, SchemaStatVisitor visitor) {
    SchemaStatVisitor isolatedVisitor = SchemaStatVisitorFactory.create(schema);
    for (String table : parseTables(statement, isolatedVisitor)) {
        TableConfig tableConfig = schema.getTables().get(table);
        Set<String> dbTypes;
        if (tableConfig != null) {
            dbTypes = tableConfig.getDbTypes();
        } else {
            dbTypes = new HashSet<>();
            dbTypes.add(schema.getDefaultDataNodeDbType());
        }

        if (dbTypes.contains("oracle")) {
            DruidSelectOracleParser oracleParser = new DruidSelectOracleParser();
            oracleParser.setInvocationHandler(SqlMethodInvocationHandlerFactory.getForOracle());
            return oracleParser;
        }
        if (dbTypes.contains("db2")) {
            return new DruidSelectDb2Parser();
        }
        if (dbTypes.contains("sqlserver")) {
            return new DruidSelectSqlServerParser();
        }
        if (dbTypes.contains("postgresql")) {
            DruidSelectPostgresqlParser pgParser = new DruidSelectPostgresqlParser();
            pgParser.setInvocationHandler(SqlMethodInvocationHandlerFactory.getForPgsql());
            return pgParser;
        }
    }
    return null;
}
 
Example 11
Source project: Mycat2 | Source file: DQLRouteTest.java | License: GNU General Public License v3.0 | Votes: 4
/**
 * Test helper that runs {@code visitor} over {@code stmt} and records the tables and table
 * aliases it finds. It then invokes the private
 * {@code DefaultDruidParser.buildRouteCalculateUnits} reflectively and returns the first route
 * unit it produced.
 *
 * <p>Side effects: adds real table names to {@code ctx}, fills {@code tableAliasMap}, copies it
 * back into the visitor's alias map, and sets it on {@code ctx}.
 *
 * @param rrs     route result set; not used by this helper
 * @param stmt    the parsed statement
 * @param visitor the visitor to run
 * @return a list holding the first route unit, or an empty list if none was produced
 * @throws Exception any reflection or parsing failure, propagated to fail the test
 */
@SuppressWarnings("unchecked")
private List<RouteCalculateUnit> visitorParse(RouteResultset rrs, SQLStatement stmt, MycatSchemaStatVisitor visitor) throws Exception {

	stmt.accept(visitor);

	List<List<Condition>> mergedConditionList;
	if (visitor.hasOrCondition()) { // contains OR
		// TODO
		// split the conditions by OR
		mergedConditionList = visitor.splitConditions();
	} else { // no OR
		mergedConditionList = new ArrayList<List<Condition>>();
		mergedConditionList.add(visitor.getConditions());
	}

	if (visitor.getAliasMap() != null) {
		for (Map.Entry<String, String> entry : visitor.getAliasMap().entrySet()) {
			String key = entry.getKey();
			if (key == null) {
				// Nothing can be recorded for a null key; dereferencing it below would throw.
				continue;
			}
			key = key.replace("`", "");
			String value = entry.getValue();
			if (value != null) {
				value = value.replace("`", "");
			}
			// Drop a leading database qualifier from the table name.
			int pos = key.indexOf(".");
			if (pos > 0) {
				key = key.substring(pos + 1);
			}

			if (key.equals(value)) {
				ctx.addTable(key.toUpperCase());
			}
			tableAliasMap.put(key.toUpperCase(), value);
		}
		visitor.getAliasMap().putAll(tableAliasMap);
		ctx.setTableAliasMap(tableAliasMap);
	}

	// Unit-test the private DefaultDruidParser.buildRouteCalculateUnits via reflection.
	Class<?> clazz = Class.forName("io.mycat.route.parser.druid.impl.DefaultDruidParser");
	Method buildRouteCalculateUnits = clazz.getDeclaredMethod("buildRouteCalculateUnits",
			SchemaStatVisitor.class, List.class);
	buildRouteCalculateUnits.setAccessible(true);
	// Constructor.newInstance replaces the deprecated Class.newInstance, which rethrew checked
	// constructor exceptions without wrapping them.
	Object parserInstance = clazz.getDeclaredConstructor().newInstance();
	Object returnValue = buildRouteCalculateUnits.invoke(parserInstance, visitor, mergedConditionList);

	List<RouteCalculateUnit> retList = new ArrayList<RouteCalculateUnit>();
	if (returnValue instanceof List<?> && !((List<?>) returnValue).isEmpty()) {
		// Keep only the first route unit.
		retList.add(((List<RouteCalculateUnit>) returnValue).get(0));
	}
	return retList;
}
 
Example 12
/**
 * Sets the schema visitor held by this object.
 *
 * @param schemaStatVisitor the visitor to store; it is not validated
 */
public void setVisitor(SchemaStatVisitor schemaStatVisitor) {
	visitor = schemaStatVisitor;
}
 
Example 13
/**
 * Returns the schema visitor held by this object.
 *
 * @return the stored visitor, exactly as last assigned
 */
public SchemaStatVisitor getVisitor() {
	return visitor;
}
 
Example 14
/**
 * Creates a new schema statistics visitor.
 *
 * <p>The {@code schema} argument is currently not consulted: every call returns a fresh
 * {@code MycatSchemaStatVisitor}.
 *
 * @param schema the schema configuration (currently unused)
 * @return a new visitor instance
 */
public static SchemaStatVisitor create(SchemaConfig schema) {
  return new MycatSchemaStatVisitor();
}