package com.pugwoo.dbhelper.impl.part; import com.pugwoo.dbhelper.DBHelperInterceptor; import com.pugwoo.dbhelper.exception.NotAllowQueryException; import com.pugwoo.dbhelper.sql.SQLUtils; import com.pugwoo.dbhelper.utils.DOInfoReader; import com.pugwoo.dbhelper.utils.NamedParameterUtils; import com.pugwoo.dbhelper.utils.PreHandleObject; import org.springframework.jdbc.core.PreparedStatementCreator; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.support.GeneratedKeyHolder; import org.springframework.jdbc.support.KeyHolder; import java.lang.reflect.Field; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.Arrays; import java.util.List; public abstract class P2_InsertOp extends P1_QueryOp { //////// 拦截器 private void doInterceptBeforeInsert(Object t) { List<Object> list = new ArrayList<Object>(); list.add(t); doInterceptBeforeInsertList(list); } private <T> void doInterceptBeforeInsertList(List<Object> list) { for (DBHelperInterceptor interceptor : interceptors) { boolean isContinue = interceptor.beforeInsert(list); if (!isContinue) { throw new NotAllowQueryException("interceptor class:" + interceptor.getClass()); } } } private void doInterceptAfterInsert(Object t, int rows) { List<Object> list = new ArrayList<Object>(); list.add(t); doInterceptAfterInsertList(list, rows); } private <T> void doInterceptAfterInsertList(final List<Object> list, final int rows) { Runnable runnable = new Runnable() { @Override public void run() { for (int i = interceptors.size() - 1; i >= 0; i--) { interceptors.get(i).afterInsert(list, rows); } } }; if(!executeAfterCommit(runnable)) { runnable.run(); } } //////////// @Override public <T> int insert(T t) { return insert(t, false, true); } @SuppressWarnings("unchecked") @Override public int insert(List<?> list) { if(list == null || list.isEmpty()) { return 0; } doInterceptBeforeInsertList((List<Object>) list); int sum = 0; for(Object obj : list) { sum += insert(obj, false, false); } doInterceptAfterInsertList((List<Object>)list, sum); return sum; } @Override public <T> int insertWithNull(T t) { return insert(t, true, true); } private <T> int insert(T t, boolean isWithNullValue, boolean withInterceptor) { PreHandleObject.preHandleInsert(t); final List<Object> values = new ArrayList<Object>(); if(withInterceptor) { doInterceptBeforeInsert(t); } final String sql = SQLUtils.getInsertSQL(t, values, isWithNullValue); log(sql); final long start = System.currentTimeMillis(); int rows = 0; Field autoIncrementField = DOInfoReader.getAutoIncrementField(t.getClass()); if (autoIncrementField != null) { GeneratedKeyHolder holder = new GeneratedKeyHolder(); rows = jdbcTemplate.update(new PreparedStatementCreator() { @Override public PreparedStatement createPreparedStatement(Connection con) throws SQLException { PreparedStatement statement = con.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS); for (int i = 0; i < values.size(); i++) { statement.setObject(i + 1, values.get(i)); } return statement; } }, holder); if(rows > 0) { long primaryKey = holder.getKey().longValue(); DOInfoReader.setValue(autoIncrementField, t, primaryKey); } } else { rows = jdbcTemplate.update(sql, values.toArray()); // 此处可以用jdbcTemplate,因为没有in (?)表达式 } long cost = System.currentTimeMillis() - start; logSlow(cost, sql, values); if(withInterceptor) { doInterceptAfterInsert(t, rows); } return rows; } @Override public <T> int insertWhereNotExist(T t, String whereSql, Object... args) { if(whereSql != null) {whereSql = whereSql.replace('\t', ' ');} return insertWhereNotExist(t, false, whereSql, args); } @Override public <T> int insertWithNullWhereNotExist(T t, String whereSql, Object... args) { if(whereSql != null) {whereSql = whereSql.replace('\t', ' ');} return insertWhereNotExist(t, true, whereSql, args); } private <T> int insertWhereNotExist(T t, boolean isWithNullValue, String whereSql, Object... args) { if(whereSql == null || whereSql.isEmpty()) { return insert(t, isWithNullValue, true); } PreHandleObject.preHandleInsert(t); List<Object> values = new ArrayList<Object>(); doInterceptBeforeInsert(t); String sql = SQLUtils.getInsertWhereNotExistSQL(t, values, isWithNullValue, whereSql); if(args != null) { for(Object arg : args) { values.add(arg); } } log(sql); long start = System.currentTimeMillis(); int rows; Field autoIncrementField = DOInfoReader.getAutoIncrementField(t.getClass()); if (autoIncrementField != null) { rows = namedJdbcExecuteUpdateWithReturnId(autoIncrementField, t, sql, values.toArray()); } else { rows = namedJdbcExecuteUpdate(sql, values.toArray()); } long cost = System.currentTimeMillis() - start; logSlow(cost, sql, values); doInterceptAfterInsert(t, rows); return rows; } private int namedJdbcExecuteUpdateWithReturnId(Field autoIncrementField, Object t, String sql, Object... args) { log(sql); long start = System.currentTimeMillis(); List<Object> argsList = new ArrayList<Object>(); // 不要直接用Arrays.asList,它不支持clear方法 if(args != null) { argsList.addAll(Arrays.asList(args)); } KeyHolder keyHolder = new GeneratedKeyHolder(); int rows = namedParameterJdbcTemplate.update( NamedParameterUtils.trans(sql, argsList), new MapSqlParameterSource(NamedParameterUtils.transParam(argsList)), keyHolder); // 因为有in (?) 所以使用namedParameterJdbcTemplate if(rows > 0) { long primaryKey = keyHolder.getKey().longValue(); DOInfoReader.setValue(autoIncrementField, t, primaryKey); } long cost = System.currentTimeMillis() - start; logSlow(cost, sql, argsList); return rows; } }