package io.mycat.route.parser.druid.impl;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLCharExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLIntegerExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement.ValuesClause;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import io.mycat.backend.mysql.nio.handler.FetchStoreNodeOfChildTableHandler;
import io.mycat.backend.mysql.nio.handler.JDBCFetchStoreNodeOfChildTableHandler;
import io.mycat.config.model.SchemaConfig;
import io.mycat.config.model.TableConfig;
import io.mycat.route.RouteResultset;
import io.mycat.route.RouteResultsetNode;
import io.mycat.route.function.AbstractPartitionAlgorithm;
import io.mycat.route.function.SlotFunction;
import io.mycat.route.parser.druid.MycatSchemaStatVisitor;
import io.mycat.route.parser.druid.RouteCalculateUnit;
import io.mycat.route.parser.util.ParseUtil;
import io.mycat.route.util.RouterUtil;
import io.mycat.server.parser.ServerParse;
import io.mycat.util.StringUtil;

import java.sql.SQLNonTransientException;
import java.sql.SQLSyntaxErrorException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Parser that routes MySQL {@code INSERT} statements.
 *
 * <p>{@link #statementParse} distinguishes the following cases:
 * <ul>
 *   <li>table that is not sharded: routed through {@code RouterUtil.routeForTableMeta};</li>
 *   <li>child table: routed in {@link #parserChildTable} using the join key value;</li>
 *   <li>table with a partition column and a single VALUES row: a sharding expression is
 *       registered on {@code ctx} ({@link #parserSingleInsert});</li>
 *   <li>table with a partition column and several VALUES rows: the rows are grouped per
 *       target node and one statement per node is produced ({@link #parserBatchInsert}).</li>
 * </ul>
 *
 * <p>{@code ctx}, {@code LOGGER} and {@code tryInvokeSQLMethod} are not declared in this class;
 * presumably they are inherited from {@code DefaultDruidParser} (verify there).
 */
public class DruidInsertParser extends DefaultDruidParser {
	/**
	 * Does nothing for INSERT statements: the body is empty and all the work is done in
	 * {@link #statementParse}.
	 */
	@Override
	public void visitorParse(RouteResultset rrs, SQLStatement stmt, MycatSchemaStatVisitor visitor) throws SQLNonTransientException {
		// no visitor based analysis for INSERT
	}

	/**
	 * Routes an INSERT statement. Aspects taken into account: whether the table is a child
	 * table, whether the insert is a batch insert, and whether the table is sharded.
	 *
	 * @param schema schema the statement is executed against
	 * @param rrs    route result that is filled in (or completed later from {@code ctx})
	 * @param stmt   parsed statement; cast to {@link MySqlInsertStatement}
	 * @throws SQLNonTransientException if the table is not defined in the schema, if a
	 *         partitioned table is inserted into without a column list, or if one of the
	 *         delegated parse methods rejects the statement
	 */
	@Override
	public void statementParse(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt) throws SQLNonTransientException {
		MySqlInsertStatement insert = (MySqlInsertStatement)stmt;
		String tableName = StringUtil.removeBackquote(insert.getTableName().getSimpleName()).toUpperCase();

		ctx.addTable(tableName);

		// the whole schema is not sharded, or this particular table is not split
		if(RouterUtil.isNoSharding(schema, tableName)) {
			RouterUtil.routeForTableMeta(rrs, schema, tableName, rrs.getStatement());
			rrs.setFinishedRoute(true);
			return;
		}

		TableConfig tc = schema.getTables().get(tableName);
		if(tc == null) {
			String msg = "can't find table define in schema "
					+ tableName + " schema:" + schema.getName();
			LOGGER.warn(msg);
			throw new SQLNonTransientException(msg);
		}

		// an insert into a child table is fully routed while parsing
		if(tc.isChildTable()) {
			parserChildTable(schema, rrs, tableName, insert);
			return;
		}

		String partitionColumn = tc.getPartitionColumn();
		if(partitionColumn == null) {
			// no partition column: nothing to compute here
			return;
		}

		// a partitioned table needs an explicit column list, otherwise the value of the
		// sharding column cannot be located
		if(insert.getColumns() == null || insert.getColumns().isEmpty()) {
			throw new SQLSyntaxErrorException("partition table, insert must provide ColumnList");
		}

		if(isMultiInsert(insert)) {
			parserBatchInsert(schema, rrs, partitionColumn, tableName, insert);
		} else {
			parserSingleInsert(schema, rrs, partitionColumn, tableName, insert);
		}
	}

	/**
	 * Finds the position of the join key inside the column list. Column names are compared
	 * without back quotes and case-insensitively.
	 *
	 * @param columns column list of the insert statement, may be {@code null}
	 * @param joinKey name of the join key column
	 * @return -1 if the join key is not present, otherwise its index (&gt;= 0)
	 */
	private int getJoinKeyIndex(List<SQLExpr> columns, String joinKey) {
		if(columns == null) {
			return -1;
		}
		int index = 0;
		for(SQLExpr columnExpr : columns) {
			String col = StringUtil.removeBackquote(columnExpr.toString());
			if(col.equalsIgnoreCase(joinKey)) {
				return index;
			}
			index++;
		}
		return -1;
	}

	/**
	 * Tells whether the statement is a batch insert, i.e. either
	 * {@code insert into ... values (),()...} or {@code insert into ... select ...}.
	 *
	 * @param insertStmt statement to inspect
	 * @return {@code true} if more than one VALUES row is present or a query is attached
	 */
	private boolean isMultiInsert(MySqlInsertStatement insertStmt) {
		if(insertStmt.getQuery() != null) {
			return true;
		}
		List<ValuesClause> rows = insertStmt.getValuesList();
		return rows != null && rows.size() > 1;
	}

	/**
	 * Returns the expression at {@code index} of the single VALUES row of the statement.
	 *
	 * @param insertStmt statement with a non-null column list
	 * @param index      position of the wanted value, taken from the column list
	 * @return the value expression at that position
	 * @throws SQLNonTransientException if the statement has no VALUES row or the row has no
	 *         value at {@code index} (fewer values than columns)
	 */
	private SQLExpr getSingleRowValue(MySqlInsertStatement insertStmt, int index) throws SQLNonTransientException {
		ValuesClause row = insertStmt.getValues();
		int valueNum = row == null ? 0 : row.getValues().size();
		if(index >= valueNum) {
			String msg = "bad insert sql columnSize != valueSize:"
					+ insertStmt.getColumns().size() + " != " + valueNum
					+ " values:" + row;
			LOGGER.warn(msg);
			throw new SQLNonTransientException(msg);
		}
		return row.getValues().get(index);
	}

	/**
	 * Routes an insert into a child table.
	 *
	 * <p>First {@code RouterUtil.routeByERParentKey} is tried with the join key value. If that
	 * yields no result, the SQL returned by {@code tc.getLocateRTableKeySql()} (lower-cased, with
	 * the join key value appended) is executed on the data nodes of the root parent table to
	 * find the node to route to.
	 *
	 * @param schema     schema the statement is executed against
	 * @param rrs        route result to fill in
	 * @param tableName  name of the child table
	 * @param insertStmt the insert statement
	 * @return the route result
	 * @throws SQLNonTransientException if the join key is missing from the column list, if the
	 *         statement is a batch insert, if the VALUES row has no value for the join key, or
	 *         if no data node could be determined
	 */
	private RouteResultset parserChildTable(SchemaConfig schema, RouteResultset rrs,
			String tableName, MySqlInsertStatement insertStmt) throws SQLNonTransientException {
		TableConfig tc = schema.getTables().get(tableName);

		String joinKey = tc.getJoinKey();
		int joinKeyIndex = getJoinKeyIndex(insertStmt.getColumns(), joinKey);
		if(joinKeyIndex == -1) {
			String inf = "joinKey not provided :" + tc.getJoinKey()+ "," + insertStmt;
			LOGGER.warn(inf);
			throw new SQLNonTransientException(inf);
		}
		if(isMultiInsert(insertStmt)) {
			String msg = "ChildTable multi insert not provided" ;
			LOGGER.warn(msg);
			throw new SQLNonTransientException(msg);
		}

		String joinKeyVal = getSingleRowValue(insertStmt, joinKeyIndex).toString();
		String sql = insertStmt.toString();

		// try to route by the ER parent partition key
		RouteResultset theRrs = RouterUtil.routeByERParentKey(null,schema, ServerParse.INSERT,sql, rrs, tc,joinKeyVal);
		if (theRrs != null) {
			rrs.setFinishedRoute(true);
			return theRrs;
		}

		// otherwise locate the data node by running the lookup sql on the root parent's nodes
		String findRootTBSql = tc.getLocateRTableKeySql().toLowerCase() + joinKeyVal;
		if (LOGGER.isDebugEnabled()) {
			LOGGER.debug("find root parent's node sql "+ findRootTBSql);
		}

		String dn;
		if (tc.getRootParent().getFetchStoreNodeByJdbc()) {
			JDBCFetchStoreNodeOfChildTableHandler jdbcFetchHandler = new JDBCFetchStoreNodeOfChildTableHandler();
			dn = jdbcFetchHandler.execute(schema.getName(),findRootTBSql, tc.getRootParent().getDataNodes());
		} else {
			FetchStoreNodeOfChildTableHandler fetchHandler = new FetchStoreNodeOfChildTableHandler();
			// the result has to be kept: dropping it left dn null and made this branch always fail
			dn = fetchHandler.execute(schema.getName(),findRootTBSql, tc.getRootParent().getDataNodes());
		}

		if (dn == null) {
			throw new SQLNonTransientException("can't find (root) parent sharding node for sql:"+ sql);
		}
		if (LOGGER.isDebugEnabled()) {
			LOGGER.debug("found partion node for child table to insert "+ dn + " sql :" + sql);
		}
		return RouterUtil.routeToSingleNode(rrs, dn, sql);
	}

	/**
	 * Handles a single row insert (not a batch) into a partitioned table.
	 *
	 * <p>The value of the sharding column is replaced in the statement by a character literal
	 * holding the computed sharding value, the rewritten SQL is stored on {@code ctx}, and a
	 * {@link RouteCalculateUnit} carrying the sharding expression is added to {@code ctx}.
	 *
	 * @param schema          schema the statement is executed against (not used here)
	 * @param rrs             route result (not used here)
	 * @param partitionColumn name of the sharding column
	 * @param tableName       name of the table
	 * @param insertStmt      the insert statement; it is modified in place
	 * @throws SQLNonTransientException if the sharding column is not in the column list, if
	 *         the VALUES row has no value for it, or if ON DUPLICATE KEY UPDATE assigns the
	 *         sharding column
	 */
	private void parserSingleInsert(SchemaConfig schema, RouteResultset rrs, String partitionColumn,
			String tableName, MySqlInsertStatement insertStmt) throws SQLNonTransientException {
		int shardingColIndex = getShardingColIndex(insertStmt, partitionColumn);
		if(shardingColIndex == -1) {
			String msg = "bad insert sql (sharding column:"+ partitionColumn + " not provided," + insertStmt;
			LOGGER.warn(msg);
			throw new SQLNonTransientException(msg);
		}

		// only one sharding column is handled, so the first match is the one to use
		String column = StringUtil.removeBackquote(insertStmt.getColumns().get(shardingColIndex).toString());
		SQLExpr valueExpr = getSingleRowValue(insertStmt, shardingColIndex);
		String shardingValue = StringUtil.removeBackquote(getShardingValue(valueExpr));
		insertStmt.getValues().getValues().set(shardingColIndex, new SQLCharExpr(shardingValue));
		ctx.setSql(insertStmt.toString());

		RouteCalculateUnit routeCalculateUnit = new RouteCalculateUnit();
		routeCalculateUnit.addShardingExpr(tableName, column, shardingValue);
		ctx.addRouteCalculateUnit(routeCalculateUnit);

		// insert into ... on duplicate key update, such as:
		// INSERT INTO TABLEName (a,b,c) VALUES (1,2,3) ON DUPLICATE KEY UPDATE b=VALUES(b);
		// INSERT INTO TABLEName (a,b,c) VALUES (1,2,3) ON DUPLICATE KEY UPDATE c=c+1;
		List<SQLExpr> updateList = insertStmt.getDuplicateKeyUpdate();
		if(updateList == null) {
			return;
		}
		for(SQLExpr expr : updateList) {
			if(!(expr instanceof SQLBinaryOpExpr)) {
				// not of the form "column <op> value", so there is no assigned column to check
				continue;
			}
			SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr)expr;
			String updatedColumn = StringUtil.removeBackquote(opExpr.getLeft().toString());
			// same case-insensitive comparison as used when locating the sharding column
			if(partitionColumn.equalsIgnoreCase(updatedColumn)) {
				String msg = "Sharding column can't be updated: " + tableName + " -> " + partitionColumn;
				LOGGER.warn(msg);
				throw new SQLNonTransientException(msg);
			}
		}
	}

	/**
	 * Handles {@code insert into table() values (),(),...}; {@code insert into ... select ...}
	 * is rejected as not supported.
	 *
	 * <p>Every VALUES row is assigned to a node index by the table's rule algorithm. The rows
	 * are grouped per node index and one {@link RouteResultsetNode} is built per group, whose
	 * statement contains only the rows of that group. The resulting nodes are set on
	 * {@code rrs} and the route is marked as finished.
	 *
	 * @param schema          schema the statement is executed against
	 * @param rrs             route result to fill in
	 * @param partitionColumn name of the sharding column
	 * @param tableName       name of the table
	 * @param insertStmt      the insert statement; its value list and, for tables reporting
	 *                        {@code isDistTable()}, its table source are modified in place
	 * @throws SQLNonTransientException if the sharding column is not in the column list, if a
	 *         row has a different number of values than there are columns, if no node index
	 *         can be computed for a row, if sub tables are missing, or if the statement is an
	 *         {@code insert ... select}
	 */
	private void parserBatchInsert(SchemaConfig schema, RouteResultset rrs, String partitionColumn,
			String tableName, MySqlInsertStatement insertStmt) throws SQLNonTransientException {
		List<ValuesClause> valueClauseList = insertStmt.getValuesList();
		if(valueClauseList == null || valueClauseList.size() <= 1) {
			if(insertStmt.getQuery() != null) { // insert into .... select ....
				String msg = "TODO:insert into .... select .... not supported!";
				LOGGER.warn(msg);
				throw new SQLNonTransientException(msg);
			}
			return;
		}

		// insert into table() values (),(),....
		int columnNum = insertStmt.getColumns().size();
		int shardingColIndex = getShardingColIndex(insertStmt, partitionColumn);
		if(shardingColIndex == -1) {
			String msg = "bad insert sql (sharding column:"+ partitionColumn + " not provided," + insertStmt;
			LOGGER.warn(msg);
			throw new SQLNonTransientException(msg);
		}

		// rows grouped by the node index computed from their sharding value
		Map<Integer,List<ValuesClause>> nodeValuesMap = new HashMap<Integer,List<ValuesClause>>();
		// slot value per node index, only filled when the algorithm is a SlotFunction
		Map<Integer,Integer> slotsMap = new HashMap<>();
		TableConfig tableConfig = schema.getTables().get(tableName);
		AbstractPartitionAlgorithm algorithm = tableConfig.getRule().getRuleAlgorithm();
		for(ValuesClause valueClause : valueClauseList) {
			if(valueClause.getValues().size() != columnNum) {
				String msg = "bad insert sql columnSize != valueSize:"
			             + columnNum + " != " + valueClause.getValues().size()
			             + "values:" + valueClause;
				LOGGER.warn(msg);
				throw new SQLNonTransientException(msg);
			}
			SQLExpr expr = valueClause.getValues().get(shardingColIndex);
			String shardingValue = StringUtil.removeBackquote(getShardingValue(expr));
			valueClause.getValues().set(shardingColIndex, new SQLCharExpr(shardingValue));

			Integer nodeIndex = algorithm.calculate(StringUtil.removeBackquote(shardingValue));
			// no target shard found for this row; checked before nodeIndex is used as a key
			if(nodeIndex == null) {
				String msg = "can't find any valid datanode :" + tableName
						+ " -> " + partitionColumn + " -> " + shardingValue;
				LOGGER.warn(msg);
				throw new SQLNonTransientException(msg);
			}
			if(algorithm instanceof SlotFunction){
				// slotValue() is read right after calculate() for the same row
				slotsMap.put(nodeIndex,((SlotFunction) algorithm).slotValue()) ;
			}
			List<ValuesClause> rowsOfNode = nodeValuesMap.get(nodeIndex);
			if(rowsOfNode == null) {
				rowsOfNode = new ArrayList<ValuesClause>();
				nodeValuesMap.put(nodeIndex, rowsOfNode);
			}
			rowsOfNode.add(valueClause);
		}

		RouteResultsetNode[] nodes = new RouteResultsetNode[nodeValuesMap.size()];
		int count = 0;
		for(Map.Entry<Integer,List<ValuesClause>> node : nodeValuesMap.entrySet()) {
			Integer nodeIndex = node.getKey();
			// restrict the statement to the rows of this node before rendering it
			insertStmt.setValuesList(node.getValue());
			if(tableConfig.isDistTable()) {
				// here the node index selects a sub table; the first data node is always used
				nodes[count] = new RouteResultsetNode(tableConfig.getDataNodes().get(0),
						rrs.getSqlType(),insertStmt.toString());
				if(tableConfig.getDistTables()==null){
					String msg = " sub table not exists for " + nodes[count].getName() + " on " + tableName;
					LOGGER.error("DruidMycatRouteStrategyError " + msg);
					throw new SQLSyntaxErrorException(msg);
				}
				String subTableName = tableConfig.getDistTables().get(nodeIndex);

				nodes[count].setSubTableName(subTableName);
				SQLInsertStatement insertStatement = insertStmt;
				SQLExprTableSource tableSource = insertStatement.getTableSource();
				// replace the table name in the statement by the sub table name
				SQLIdentifierExpr sqlIdentifierExpr = new SQLIdentifierExpr();
				sqlIdentifierExpr.setParent(tableSource.getParent());
				sqlIdentifierExpr.setName(subTableName);
				SQLExprTableSource from2 = new SQLExprTableSource(sqlIdentifierExpr);
				insertStatement.setTableSource(from2);
				nodes[count].setStatement(insertStatement.toString());
			} else {
				nodes[count] = new RouteResultsetNode(tableConfig.getDataNodes().get(nodeIndex),
						rrs.getSqlType(),insertStmt.toString());
			}

			if(algorithm instanceof SlotFunction) {
				nodes[count].setSlot(slotsMap.get(nodeIndex));
				nodes[count].setStatement(ParseUtil.changeInsertAddSlot(nodes[count].getStatement(),nodes[count].getSlot()));
			}
			nodes[count++].setSource(rrs);
		}
		rrs.setNodes(nodes);
		rrs.setFinishedRoute(true);
	}

	/**
	 * Converts a value expression into the string used as sharding value.
	 *
	 * <ul>
	 *   <li>integer literal: the number as text;</li>
	 *   <li>character literal: its text;</li>
	 *   <li>method invocation: the result of {@code tryInvokeSQLMethod}; if that throws (the
	 *       exception is logged) or returns {@code null}, the expression text is used;</li>
	 *   <li>anything else: the expression text.</li>
	 * </ul>
	 *
	 * @param expr value expression of the sharding column
	 * @return the sharding value as string
	 * @throws SQLNonTransientException declared for callers; not thrown by the code visible here
	 */
	private String getShardingValue(SQLExpr expr) throws SQLNonTransientException {
		if(expr instanceof SQLIntegerExpr) {
			return ((SQLIntegerExpr)expr).getNumber() + "";
		}
		if(expr instanceof SQLCharExpr) {
			return ((SQLCharExpr)expr).getText();
		}
		if(expr instanceof SQLMethodInvokeExpr) {
			String evaluated = null;
			try {
				evaluated = tryInvokeSQLMethod((SQLMethodInvokeExpr)expr);
			}catch (Exception e){
				// best effort: fall back to the expression text below
				LOGGER.error("",e);
			}
			return evaluated != null ? evaluated : expr.toString();
		}
		return expr.toString();
	}

	/**
	 * Finds the position of the sharding column inside the column list of the statement.
	 * Column names are compared without back quotes and case-insensitively.
	 *
	 * @param insertStmt      statement whose column list is searched
	 * @param partitionColumn name of the sharding column
	 * @return index of the sharding column, or -1 if it is not in the column list
	 */
	private int getShardingColIndex(MySqlInsertStatement insertStmt,String partitionColumn) {
		List<SQLExpr> columns = insertStmt.getColumns();
		for(int i = 0; i < columns.size(); i++) {
			String col = StringUtil.removeBackquote(columns.get(i).toString());
			if(partitionColumn.equalsIgnoreCase(col)) {
				return i;
			}
		}
		return -1;
	}
}