/* * Copyright (c) 2011-2018, Meituan Dianping. All Rights Reserved. * * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.dianping.zebra.shard.router; import java.util.*; import java.util.Map.Entry; import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; import com.dianping.zebra.shard.exception.ShardParseException; import com.dianping.zebra.shard.exception.ShardRouterException; import com.dianping.zebra.shard.merge.MergeContext; import com.dianping.zebra.shard.parser.*; import com.dianping.zebra.shard.router.RouterResult.RouterTarget; import com.dianping.zebra.shard.router.rule.RouterRule; import com.dianping.zebra.shard.router.rule.ShardEvalContext; import com.dianping.zebra.shard.router.rule.ShardEvalResult; import com.dianping.zebra.shard.router.rule.TableShardRule; /** * @author hao.zhu */ public class DefaultShardRouter implements ShardRouter { private SQLRewrite sqlRewrite = new DefaultSQLRewrite(); private RouterRule routerRule; private String defaultDatasource; // 针对不分库不分表的表 private boolean optimizeShardKeyInSql; public DefaultShardRouter(RouterRule routerRule, String defaultDatasource) { this(routerRule, new DefaultSQLRewrite(), defaultDatasource); } public DefaultShardRouter(RouterRule routerRule, SQLRewrite sqlRewrite, String defaultDatasource) { this.routerRule = routerRule; this.sqlRewrite = sqlRewrite; this.defaultDatasource = defaultDatasource; } @Override public RouterResult router(final String sql, List<Object> params) throws ShardRouterException, ShardParseException { SQLParsedResult parsedResult = SQLParser.parseWithCache(sql); boolean optimizeIn = false; RouterResult routerResult = new RouterResult(); SQLHint sqlHint = ((parsedResult instanceof MultiSQLParsedResult) ? ((MultiSQLParsedResult) parsedResult).getSqlHint() : parsedResult.getRouterContext().getSqlhint()); if (sqlHint != null) { routerResult.setConcurrencyLevel(sqlHint.getConcurrencyLevel()); Boolean optimizeInObj = sqlHint.getOptimizeIn(); optimizeIn = (optimizeInObj == null) ? this.optimizeShardKeyInSql : optimizeInObj; } // multi queries if (parsedResult instanceof MultiSQLParsedResult) { routerResult.setOptimizeShardKeyInSql(false); return multiQueriesRouter((MultiSQLParsedResult)parsedResult, params, routerResult); } List<TableShardRule> findShardRules = findShardRules(parsedResult.getRouterContext(), params); if (findShardRules.size() == 1) { return routerOneRule(parsedResult, params, routerResult, findShardRules.get(0), optimizeIn); } else if(findShardRules.size() > 1) { return routerMultiRules(parsedResult, params, routerResult, findShardRules); } else { return routerDefault(parsedResult, params, routerResult, sql); } } private RouterResult multiQueriesRouter(MultiSQLParsedResult multiSQLParsedResult, List<Object> params, RouterResult routerResult) { List<SQLParsedResult> sqlParsedResults = multiSQLParsedResult.getSqlParsedResults(); if (sqlParsedResults != null) { List<RouterResult> routerResults = new ArrayList<RouterResult>(); for (SQLParsedResult sqlParsedResult : sqlParsedResults) { List<TableShardRule> findShardRules = findShardRules(sqlParsedResult.getRouterContext(), params); if (findShardRules.size() != 1) { throw new ShardRouterException("Shard multi queries not support multi table rule or no table rule!"); } RouterResult singleResult = routerOneRule(sqlParsedResult, params, new RouterResult(), findShardRules.get(0), false, true); routerResults.add(singleResult); params = singleResult.getParams(); } Map<String, StringBuilder> mergeSqlMap = new HashMap<String, StringBuilder>(); Map<String, Set<Integer>> variantRefIndexMap = new HashMap<String, Set<Integer>>(); for (RouterResult singleResult : routerResults) { List<RouterTarget> routerTargets = singleResult.getSqls(); if (routerResults != null) { for (RouterTarget target : routerTargets) { String db = target.getDatabaseName(); StringBuilder builder = mergeSqlMap.get(db); if (builder == null) { builder = new StringBuilder(4096); mergeSqlMap.put(db, builder); } Set<Integer> variantRefIndexSet = variantRefIndexMap.get(db); if (variantRefIndexSet == null) { variantRefIndexSet = new HashSet<Integer>(); variantRefIndexMap.put(db, variantRefIndexSet); } List<String> sqls = target.getSqls(); List<Set<Integer>> variantRefIndexList = target.getAllVariantRefIndexList(); if (sqls != null) { Iterator<Set<Integer>> it = (variantRefIndexList == null) ? null : variantRefIndexList.iterator(); for (String sql : sqls) { builder.append(sql).append(';'); if (it != null && it.hasNext()) { variantRefIndexSet.addAll(it.next()); } } } } } } List<RouterTarget> routerTargets = new ArrayList<RouterTarget>(mergeSqlMap.size()); for (Map.Entry<String, StringBuilder> entry : mergeSqlMap.entrySet()) { RouterTarget routerTarget = new RouterTarget(entry.getKey(), Arrays.asList(entry.getValue().toString())); routerTarget.setAllVariantRefIndexList(Arrays.asList(variantRefIndexMap.get(entry.getKey()))); routerTargets.add(routerTarget); } routerResult.setSqls(routerTargets); routerResult.setParams(params); routerResult.setMultiQueries(true); } return routerResult; } // one table private RouterResult routerOneRule(SQLParsedResult parsedResult, List<Object> params, RouterResult routerResult, TableShardRule tableShardRule, boolean optimizeIn) { return routerOneRule(parsedResult, params, routerResult, tableShardRule, optimizeIn, false); } private RouterResult routerOneRule(SQLParsedResult parsedResult, List<Object> params, RouterResult routerResult, TableShardRule tableShardRule, boolean optimizeIn, boolean multiQueries) { ShardEvalResult shardResult = tableShardRule.eval(new ShardEvalContext(parsedResult, params, optimizeIn)); routerResult.setMergeContext(new MergeContext(parsedResult.getMergeContext())); if (shardResult.isBatchInsert()) { routerResult.setBatchInsert(true); buildBatchInsertSqls(shardResult, parsedResult, tableShardRule.getTableName(), routerResult); routerResult.setParams(buildParams(params, routerResult)); } else { if (optimizeIn && shardResult.isOptimizeShardKeyInSql()) { routerResult.setSqls(buildSqls(shardResult.getDbAndTables(), parsedResult, tableShardRule.getTableName(), shardResult.getSkInExprWrapperMap(), shardResult.getShardColumns())); routerResult.setOptimizeShardKeyInSql(true); } else if (multiQueries) { routerResult.setSqls(buildMultiQueriesSqls(shardResult.getDbAndTables(), parsedResult, tableShardRule.getTableName())); } else { routerResult.setSqls(buildSqls(shardResult.getDbAndTables(), parsedResult, tableShardRule.getTableName())); } routerResult.setParams(buildParams(params, routerResult)); } return routerResult; } // multi table for binding table private RouterResult routerMultiRules(SQLParsedResult parsedResult, List<Object> params, RouterResult routerResult, List<TableShardRule> findShardRules) { List<ShardEvalResult> shardResults = new ArrayList<ShardEvalResult>(); ShardEvalContext shardEvalContext = new ShardEvalContext(parsedResult, params); for (TableShardRule tableShardRule : findShardRules) { shardResults.add(tableShardRule.eval(shardEvalContext)); } Map<String, List<Map<String, String>>> dbAndTables = new HashMap<String, List<Map<String, String>>>(); for (ShardEvalResult shardResult : shardResults) { String logicalTable = shardResult.getLogicalTable(); for (Entry<String, Set<String>> entry : shardResult.getDbAndTables().entrySet()) { String db = entry.getKey(); List<Map<String, String>> tableMappingList = dbAndTables.get(db); if (tableMappingList == null) { int size = entry.getValue().size(); tableMappingList = new ArrayList<Map<String, String>>(size); for (int i = 0; i < size; i++) { tableMappingList.add(new HashMap<String, String>()); } dbAndTables.put(db, tableMappingList); } int index = 0; for (String physicalTable : entry.getValue()) { Map<String, String> tableMapping = tableMappingList.get(index++); tableMapping.put(logicalTable, physicalTable); } } } routerResult.setMergeContext(new MergeContext(parsedResult.getMergeContext())); routerResult.setSqls(buildSqls(dbAndTables, parsedResult)); routerResult.setParams(buildParams(params, routerResult)); return routerResult; } // single table default router private RouterResult routerDefault(SQLParsedResult parsedResult, List<Object> params, RouterResult routerResult, String sql) { // add for default strategy List<RouterTarget> routerSqls = new ArrayList<RouterTarget>(); RouterTarget targetedSql = new RouterTarget(defaultDatasource); targetedSql.addSql(sql); routerSqls.add(targetedSql); List<Object> newParams = null; if (params != null) { newParams = new ArrayList<Object>(params); } routerResult.setMergeContext(new MergeContext(parsedResult.getMergeContext())); routerResult.setSqls(routerSqls); routerResult.setParams(newParams); return routerResult; } @Override public boolean validate(String sql) throws ShardParseException, ShardRouterException { return true; } @Override public RouterRule getRouterRule() { return this.routerRule; } private List<TableShardRule> findShardRules(RouterContext context, List<Object> params) throws ShardRouterException { Map<String, TableShardRule> tableShardRules = this.routerRule.getTableShardRules(); List<TableShardRule> tableShardRuleList = new ArrayList<TableShardRule>(); for (String relatedTable : context.getTableSet()) { TableShardRule tableShardRule = tableShardRules.get(relatedTable); if (tableShardRule != null) { tableShardRuleList.add(tableShardRule); } } if(tableShardRuleList.size() > 1) { tableShardRuleList = new ArrayList<TableShardRule>(); for (String relatedTable : context.getTableSet()) { TableShardRule tableShardRule = tableShardRules.get(relatedTable); if (tableShardRule != null) { tableShardRuleList.add(tableShardRule); } } } else if(tableShardRuleList.isEmpty()){ // throw exception if no default jdbcRef if (defaultDatasource == null) { throw new ShardRouterException("No table shard rule can be found for table " + context.getTableSet()); } } return tableShardRuleList; } // build normal sql and multi queries private List<RouterTarget> buildSqls(Map<String, Set<String>> dbAndTables, SQLParsedResult parseResult, String logicTable) { List<RouterTarget> sqls = new ArrayList<RouterTarget>(); for (Entry<String, Set<String>> entry : dbAndTables.entrySet()) { RouterTarget targetedSql = new RouterTarget(entry.getKey()); for (String physicalTable : entry.getValue()) { String _sql = sqlRewrite.rewrite(parseResult, logicTable, physicalTable); String hintComment = parseResult.getRouterContext().getSqlhint().getHintComments(); targetedSql.addSql((hintComment != null ? (hintComment + _sql) : _sql)); } sqls.add(targetedSql); } return sqls; } // build multi queries sql private List<RouterTarget> buildMultiQueriesSqls(Map<String, Set<String>> dbAndTables, SQLParsedResult parseResult, String logicTable) { List<RouterTarget> sqls = new ArrayList<RouterTarget>(); for (Entry<String, Set<String>> entry : dbAndTables.entrySet()) { RouterTarget targetedSql = new RouterTarget(entry.getKey()); for (String physicalTable : entry.getValue()) { Set<Integer> variantRefIndexSet = new HashSet<Integer>(); String _sql = sqlRewrite.rewrite(parseResult, logicTable, physicalTable, variantRefIndexSet); String hintComment = parseResult.getRouterContext().getSqlhint().getHintComments(); targetedSql.addSql((hintComment != null ? (hintComment + _sql) : _sql)); targetedSql.addVariantRefIndexes(variantRefIndexSet); } sqls.add(targetedSql); } return sqls; } // build sql optimize in private List<RouterTarget> buildSqls(Map<String, Set<String>> dbAndTables, SQLParsedResult parseResult, String logicTable, Map<String, Map<String, Set<SQLInExprWrapper>>> skInSqlExprMap, Set<String> shardColumns) { List<RouterTarget> sqls = new ArrayList<RouterTarget>(); for (Entry<String, Set<String>> entry : dbAndTables.entrySet()) { RouterTarget targetedSql = new RouterTarget(entry.getKey()); Map<String, Set<SQLInExprWrapper>> skInMap = skInSqlExprMap.get(entry.getKey()); for (String physicalTable : entry.getValue()) { Set<SQLInExprWrapper> skInSet = null; if (skInMap != null) { skInSet = skInMap.get(physicalTable); } Set<Integer> skInIgnoreParams = new HashSet<Integer>(); String _sql = sqlRewrite.rewrite(parseResult, logicTable, physicalTable, null, new SQLRewriteInParam(skInSet, shardColumns, skInIgnoreParams)); String hintComment = parseResult.getRouterContext().getSqlhint().getHintComments(); if (hintComment != null) { targetedSql.addSql(hintComment + _sql); } else { targetedSql.addSql(_sql); } targetedSql.addSkInIgnoreParams(skInIgnoreParams); } sqls.add(targetedSql); } return sqls; } // binding table private List<RouterTarget> buildSqls(Map<String, List<Map<String, String>>> dbAndTables, SQLParsedResult parseResult) { List<RouterTarget> sqls = new ArrayList<RouterTarget>(); for (Entry<String, List<Map<String, String>>> entry : dbAndTables.entrySet()) { RouterTarget targetedSql = new RouterTarget(entry.getKey()); for (Map<String, String> tables : entry.getValue()) { String _sql = sqlRewrite.rewrite(parseResult.getStmt(), tables); String hintComment = parseResult.getRouterContext().getSqlhint().getForceMasterComment(); if (hintComment != null) { targetedSql.addSql(hintComment + _sql); } else { targetedSql.addSql(_sql); } } sqls.add(targetedSql); } return sqls; } private List<Object> buildParams(List<Object> params, RouterResult rr) { List<Object> newParams = null; if (params != null) { newParams = new ArrayList<Object>(params); MySqlSelectQueryBlock.Limit limitExpr = rr.getMergeContext().getLimitExpr(); if (limitExpr != null) { int offset = Integer.MIN_VALUE; int offsetRefIndex = -1; int limit = Integer.MIN_VALUE; int limitRefIndex = -1; int originOffset = Integer.MIN_VALUE; boolean isSingleTarget = isSingleRouterTarget(rr); if (limitExpr.getOffset() instanceof SQLVariantRefExpr) { SQLVariantRefExpr ref = (SQLVariantRefExpr) limitExpr.getOffset(); offsetRefIndex = ref.getIndex(); offset = (Integer) newParams.get(ref.getIndex()); originOffset = offset; if (!isSingleTarget) { rr.getMergeContext().setOffset(offset); // 不是可拆分limit SQL才重写offset if (!rr.getMergeContext().isOrderBySplitSql()) { offset = 0; } } } if (limitExpr.getRowCount() instanceof SQLVariantRefExpr) { SQLVariantRefExpr ref = (SQLVariantRefExpr) limitExpr.getRowCount(); limitRefIndex = ref.getIndex(); limit = (Integer) newParams.get(ref.getIndex()); if (!isSingleTarget) { rr.getMergeContext().setLimit(limit); if (originOffset != Integer.MIN_VALUE) { limit = originOffset + limit; } } } if (offsetRefIndex > limitRefIndex && offsetRefIndex != -1 && limitRefIndex != -1) { newParams.set(limitRefIndex, offset); newParams.set(offsetRefIndex, limit); } else { if (limitRefIndex != -1) { newParams.set(limitRefIndex, limit); } if (offsetRefIndex != -1) { newParams.set(offsetRefIndex, offset); } } } } return newParams; } // batch insert private void buildBatchInsertSqls(ShardEvalResult shardResult, SQLParsedResult parseResult, String logicTable, RouterResult routerResult) { List<RouterTarget> sqls = new ArrayList<RouterTarget>(); Map<String, Set<String>> dbAndTables = shardResult.getDbAndTables(); Map<String, Map<String, Set<Integer>>> insertClauseIndexMap = shardResult.getInsertClauseIndexMap(); MySqlInsertStatement stmt = (MySqlInsertStatement) parseResult.getStmt(); int clauseColumnSize = stmt.getColumns().size(); for (Entry<String, Set<String>> entry : dbAndTables.entrySet()) { String database = entry.getKey(); RouterTarget targetedSql = new RouterTarget(database); Map<String, Set<Integer>> tbParamIndexMap = insertClauseIndexMap.get(database); int tableIndex = 0; for (String physicalTable : entry.getValue()) { Set<Integer> paramIndexes = null; if (tbParamIndexMap != null) { paramIndexes = tbParamIndexMap.get(physicalTable); } String newSql = sqlRewrite.rewrite(parseResult, logicTable, physicalTable, paramIndexes, null); String hintComment = parseResult.getRouterContext().getSqlhint().getHintComments(); if (hintComment != null) { newSql = hintComment + newSql; } targetedSql.addSql(newSql); targetedSql.addPhysicalTable(physicalTable); List<Integer> mappingList = new ArrayList<Integer>(); if (paramIndexes != null) { for (Integer index : paramIndexes) { for(int i = index * clauseColumnSize; i < (index + 1) * clauseColumnSize; ++i) { mappingList.add(i+1); } } } targetedSql.putParamIndexMappingList(tableIndex, mappingList); tableIndex++; } sqls.add(targetedSql); } routerResult.setSqls(sqls); } public boolean isSingleRouterTarget(RouterResult routerResult) { if (routerResult.getSqls().size() > 1) { return false; } RouterTarget routerTarget = routerResult.getSqls().get(0); if (routerTarget.getSqls().size() > 1) { return false; } return true; } public boolean isOptimizeShardKeyInSql() { return optimizeShardKeyInSql; } public void setOptimizeShardKeyInSql(boolean optimizeShardKeyInSql) { this.optimizeShardKeyInSql = optimizeShardKeyInSql; } }