package com.dfire.platform.alchemy.api.util; import com.dfire.platform.alchemy.api.common.Alias; import org.apache.calcite.config.Lex; import org.apache.calcite.sql.SqlAsOperator; import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlBinaryOperator; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlInsert; import org.apache.calcite.sql.SqlJoin; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.commons.lang3.StringUtils; import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList; import org.apache.flink.calcite.shaded.com.google.common.collect.Lists; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; /** * @author congbai * @date 2019/5/21 */ public class SideParser { private static final SqlParser.Config CONFIG = SqlParser.configBuilder().setLex(Lex.MYSQL).build(); public static Deque<SqlNode> parse(String sql) throws SqlParseException { SqlParser sqlParser = SqlParser.create(sql, CONFIG); SqlNode sqlNode = sqlParser.parseStmt(); Deque<SqlNode> deque = new ArrayDeque<>(); parse(sqlNode, deque); return deque; } public static void parse(SqlNode sqlNode, Deque<SqlNode> deque) { deque.offer(sqlNode); SqlKind sqlKind = sqlNode.getKind(); switch (sqlKind) { case INSERT: SqlNode sqlSource = ((SqlInsert)sqlNode).getSource(); parse(sqlSource, deque); break; case SELECT: SqlNode sqlFrom = ((SqlSelect)sqlNode).getFrom(); parse(sqlFrom, deque); break; case JOIN: SqlNode sqlLeft = ((SqlJoin)sqlNode).getLeft(); SqlNode sqlRight = ((SqlJoin)sqlNode).getRight(); parse(sqlLeft, deque); parse(sqlRight, deque); break; case AS: SqlNode sqlAs = ((SqlBasicCall)sqlNode).getOperands()[0]; parse(sqlAs, deque); break; default: return; } } public static Alias getTableName(SqlNode sqlNode) { SqlKind sqlKind = sqlNode.getKind(); Alias alias; switch (sqlKind) { case IDENTIFIER: SqlIdentifier sqlIdentifier = (SqlIdentifier)sqlNode; alias = new Alias(sqlIdentifier.names.get(0), sqlIdentifier.names.get(0)); break; case AS: SqlBasicCall sqlBasicCall = (SqlBasicCall)sqlNode; SqlNode first = sqlBasicCall.getOperands()[0]; SqlNode second = sqlBasicCall.getOperands()[1]; if (first.getKind() == SqlKind.IDENTIFIER) { alias = new Alias(((SqlIdentifier)first).names.get(0), ((SqlIdentifier)second).names.get(0)); } else { alias = new Alias(((SqlIdentifier)second).names.get(0), ((SqlIdentifier)second).names.get(0)); } break; default: throw new UnsupportedOperationException("暂时不支持" + sqlKind); } return alias; } public static void rewrite(SqlNode sqlNode, SqlSelect sqlSelect) { SqlKind sqlKind = sqlNode.getKind(); switch (sqlKind) { case INSERT: SqlInsert sqlInsert = ((SqlInsert)sqlNode); sqlInsert.setSource(sqlSelect); break; case SELECT: SqlSelect select = (SqlSelect)sqlNode; select.setFrom(sqlSelect); break; case AS: SqlBasicCall basicCall = (SqlBasicCall)sqlNode; basicCall.setOperand(0, sqlSelect); break; default: throw new UnsupportedOperationException(sqlKind + "目前不支持维表操作"); } } /** * select a.name , a.age , FUN(a.weight) as weight from test --> { name , age , weight} * @param selectList * @return */ public static List<String> findSelectField(SqlNodeList selectList){ List<SqlNode> nodes = selectList.getList(); List<String> fields = new ArrayList<>(); for (SqlNode node : nodes){ SqlKind kind = node.getKind(); String field; switch (kind){ case AS: SqlBasicCall call = (SqlBasicCall) node; field = findField(call.operand(0)); break; case IDENTIFIER: field = findField(node); break; default: throw new UnsupportedOperationException("Don't supported findSelectField in" + node); } if (StringUtils.isEmpty(field)){ // a.* return Collections.emptyList(); }else{ fields.add(field); } } return fields; } private static String findField(SqlNode sqlNode){ SqlKind kind = sqlNode.getKind(); switch (kind){ case IDENTIFIER: SqlIdentifier identifier = (SqlIdentifier) sqlNode; ImmutableList<String> names = identifier.names; if (names.size()== 1){ return names.get(0); }else if(names.size() == 2){ return names.get(1); } default: throw new UnsupportedOperationException("Don't supported findField in" + sqlNode); } } public static List<String> findConditionFields(SqlNode conditionNode, String specifyTableName){ List<SqlNode> sqlNodeList = Lists.newArrayList(); if(conditionNode.getKind() == SqlKind.AND){ sqlNodeList.addAll(Lists.newArrayList(((SqlBasicCall)conditionNode).getOperands())); }else{ sqlNodeList.add(conditionNode); } List<String> conditionFields = Lists.newArrayList(); for(SqlNode sqlNode : sqlNodeList){ if(sqlNode.getKind() != SqlKind.EQUALS){ throw new RuntimeException("not equal operator."); } SqlIdentifier left = (SqlIdentifier)((SqlBasicCall)sqlNode).getOperands()[0]; SqlIdentifier right = (SqlIdentifier)((SqlBasicCall)sqlNode).getOperands()[1]; String leftTableName = left.names.get(0); String rightTableName = right.names.get(0); String tableCol; if(leftTableName.equalsIgnoreCase(specifyTableName)){ tableCol = left.names.get(1); }else if(rightTableName.equalsIgnoreCase(specifyTableName)){ tableCol = right.names.get(1); }else{ throw new RuntimeException(String.format("side table:%s join condition is wrong", specifyTableName)); } conditionFields.add(tableCol); } return conditionFields; } public static SqlSelect newSelect(SqlSelect selectSelf, String table, String alias, boolean left, boolean newTable) { List<SqlNode> operand = selectSelf.getOperandList(); SqlNodeList keywordList = (SqlNodeList)operand.get(0); SqlNodeList selectList = (SqlNodeList)operand.get(1); SqlNode from = operand.get(2); SqlNode where = operand.get(3); SqlNodeList groupBy = (SqlNodeList)operand.get(4); SqlNode having = operand.get(5); SqlNodeList windowDecls = (SqlNodeList)operand.get(6); SqlNodeList orderBy = (SqlNodeList)operand.get(7); SqlNode offset = operand.get(8); SqlNode fetch = operand.get(9); if (left) { return newSelect(selectSelf.getParserPosition(), keywordList, selectList, ((SqlJoin)from).getLeft(), where, groupBy, having, windowDecls, orderBy, offset, fetch, alias, newTable); } if (newTable) { return newSelect(selectSelf.getParserPosition(), null, creatFullNewSelectList(alias, selectList), createNewFrom(table, alias, from), where, groupBy, having, windowDecls, orderBy, offset, fetch, alias, newTable); } else { return newSelect(selectSelf.getParserPosition(), null, selectList, ((SqlJoin)from).getRight(), where, groupBy, having, windowDecls, orderBy, offset, fetch, alias, newTable); } } private static SqlNode createNewFrom(String table, String alias, SqlNode from) { SqlIdentifier identifierFirst = new SqlIdentifier(table, from.getParserPosition()); SqlIdentifier identifierSecond = new SqlIdentifier(alias, from.getParserPosition()); return new SqlBasicCall(new SqlAsOperator(), new SqlNode[] {identifierFirst, identifierSecond}, from.getParserPosition()); } private static SqlNodeList creatFullNewSelectList(String alias, SqlNodeList selectList) { SqlNodeList newSelectList = new SqlNodeList( selectList.getParserPosition()); List<String> names = new ArrayList<>(2); names.add(alias); names.add(""); newSelectList.add(new SqlIdentifier(names,new SqlParserPos(0,0))); return newSelectList; } private static SqlSelect newSelect(SqlParserPos parserPosition, SqlNodeList keywordList, SqlNodeList selectList, SqlNode fromNode, SqlNode whereNode, SqlNodeList groupByNode, SqlNode havingNode, SqlNodeList windowDeclsList, SqlNodeList orderByList, SqlNode offsetNode, SqlNode fetchNode, String alias, boolean newTable) { SqlNodeList keyword = keywordList; SqlNodeList select = newTable ? changeTableName(selectList,alias): reduce(selectList, alias); SqlNode from = fromNode == null ? null : fromNode; SqlNode where =newTable ? changeTableName(whereNode,alias): reduce(whereNode, alias); SqlNodeList groupBy = newTable ? changeTableName(groupByNode,alias): reduce(groupByNode, alias); SqlNode having = newTable ? changeTableName(havingNode,alias): reduce(havingNode, alias); SqlNodeList windowDecls = newTable ? changeTableName(windowDeclsList,alias): reduce(windowDeclsList, alias); SqlNodeList orderBy = newTable ? changeTableName(orderByList,alias): reduce(orderByList, alias); SqlNode offset = newTable ? changeTableName(offsetNode,alias): reduce(offsetNode, alias); SqlNode fetch = newTable ? changeTableName(fetchNode,alias): reduce(fetchNode, alias); return new SqlSelect(parserPosition, keyword, select, from, where, groupBy, having, windowDecls, orderBy, offset, fetch); } private static SqlNode reduce(SqlNode sqlNode, String alias) { if (sqlNode == null) { return null; } SqlNode cloneNode = sqlNode.clone(sqlNode.getParserPosition()); SqlKind sqlKind = cloneNode.getKind(); switch (sqlKind) { case IDENTIFIER: SqlIdentifier sqlIdentifier = (SqlIdentifier)cloneNode; String tableName = sqlIdentifier.names.get(0); if (tableName.equalsIgnoreCase(alias)) { return sqlIdentifier; } else { return null; } case OR: case AND: SqlBasicCall call = (SqlBasicCall)cloneNode; SqlNode[] nodes = call.getOperands(); List<SqlNode> sqlNodeList = new ArrayList<>(nodes.length); for (int i = 0; i < nodes.length; i++) { SqlNode node = reduce(nodes[i], alias); if (node != null) { sqlNodeList.add(node); } } if (sqlNodeList.size() == 1) { SqlBinaryOperator equal = new SqlBinaryOperator("=", SqlKind.EQUALS, 30, true, ReturnTypes.BOOLEAN_NULLABLE, InferTypes.FIRST_KNOWN, OperandTypes.COMPARABLE_UNORDERED_COMPARABLE_UNORDERED); SqlBasicCall andEqual = new SqlBasicCall(equal, createEqualNodes(sqlKind), new SqlParserPos(0, 0)); sqlNodeList.add(andEqual); return call.getOperator().createCall(call.getFunctionQuantifier(), call.getParserPosition(), sqlNodeList.toArray(new SqlNode[sqlNodeList.size()])); } else if (sqlNodeList.size() > 1) { return call.getOperator().createCall(call.getFunctionQuantifier(), call.getParserPosition(), sqlNodeList.toArray(new SqlNode[sqlNodeList.size()])); } else { return null; } default: if (sqlNode instanceof SqlBasicCall) { SqlBasicCall sqlBasicCall = (SqlBasicCall)cloneNode; SqlNode node = reduce(sqlBasicCall.getOperands()[0], alias); if (node == null) { return null; } else { SqlBasicCall basicCall = (SqlBasicCall)sqlBasicCall.getOperator().createCall(sqlBasicCall.getFunctionQuantifier(), sqlBasicCall.getParserPosition(), Arrays.copyOf(sqlBasicCall.getOperands(), sqlBasicCall.getOperands().length)); basicCall.setOperand(0, node); return basicCall; } } else { throw new UnsupportedOperationException("can't find tableName"); } } } private static SqlNodeList reduce(SqlNodeList sqlNodes, String alias) { if (sqlNodes == null) { return sqlNodes; } SqlNodeList nodes = sqlNodes.clone(new SqlParserPos(0, 0)); List<SqlNode> newNodes = new ArrayList<>(nodes.size()); Iterator<SqlNode> sqlNodeIterable = nodes.iterator(); while (sqlNodeIterable.hasNext()) { SqlNode sqlNode = sqlNodeIterable.next(); sqlNode = reduce(sqlNode, alias); if (sqlNode != null) { newNodes.add(sqlNode); } } if (newNodes.size() > 0) { return new SqlNodeList(newNodes, nodes.getParserPosition()); } else { return null; } } private static SqlNodeList changeTableName(SqlNodeList sqlNodes, String alias) { if (sqlNodes == null) { return sqlNodes; } SqlNodeList nodes = sqlNodes.clone(new SqlParserPos(0, 0)); List<SqlNode> newNodes = new ArrayList<>(nodes.size()); Iterator<SqlNode> sqlNodeIterable = nodes.iterator(); while (sqlNodeIterable.hasNext()) { SqlNode sqlNode = sqlNodeIterable.next(); sqlNode = changeTableName(sqlNode, alias); newNodes.add(sqlNode); } return new SqlNodeList(newNodes, nodes.getParserPosition()); } public static SqlNode changeTableName(SqlNode sqlNode, String alias) { if (sqlNode == null){ return null; } SqlKind sqlKind = sqlNode.getKind(); switch (sqlKind) { case IDENTIFIER: SqlIdentifier sqlIdentifier = new SqlIdentifier( new ArrayList<>(((SqlIdentifier)sqlNode).names.asList()), sqlNode.getParserPosition()); return sqlIdentifier.setName(0, alias); case OR: case AND: SqlBasicCall call = (SqlBasicCall)sqlNode; SqlNode[] nodes = call.getOperands(); List<SqlNode> sqlNodeList = new ArrayList<>(nodes.length); for (int i = 0; i < nodes.length; i++) { SqlNode node = changeTableName(nodes[i], alias); sqlNodeList.add(node); } return call.getOperator().createCall(call.getFunctionQuantifier(), call.getParserPosition(), sqlNodeList.toArray(new SqlNode[sqlNodeList.size()])); default: if (sqlNode instanceof SqlBasicCall) { SqlBasicCall sqlBasicCall = (SqlBasicCall)sqlNode; SqlNode node = changeTableName(sqlBasicCall.getOperands()[0], alias); SqlBasicCall basicCall = (SqlBasicCall)sqlBasicCall.getOperator().createCall(sqlBasicCall.getFunctionQuantifier(), sqlBasicCall.getParserPosition(), Arrays.copyOf(sqlBasicCall.getOperands(), sqlBasicCall.getOperands().length)); basicCall.setOperand(0, node); return basicCall; } else { throw new UnsupportedOperationException("don't support " +sqlNode); } } } public static SqlNode[] createEqualNodes(SqlKind sqlKind) { SqlNode[] nodes = new SqlNode[2]; if (SqlKind.AND == sqlKind) { nodes[0] = SqlLiteral.createExactNumeric("1", new SqlParserPos(0, 0)); nodes[1] = SqlLiteral.createExactNumeric("1", new SqlParserPos(0, 0)); } else { nodes[0] = SqlLiteral.createExactNumeric("0", new SqlParserPos(0, 0)); nodes[1] = SqlLiteral.createExactNumeric("1", new SqlParserPos(0, 0)); } return nodes; } }