package org.springframework.data.mybatis.repository.support; import java.util.HashMap; import java.util.List; import java.util.Map; import org.mybatis.spring.SqlSessionTemplate; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; /** * Convenient super class for MyBatis SqlSession data access objects. It gives you access * to the template which can then be used to execute SQL methods. * * @author JARVIS SONG */ public abstract class SqlSessionRepositorySupport { public static final char DOT = '.'; private final SqlSessionTemplate sqlSession; protected SqlSessionRepositorySupport(SqlSessionTemplate sqlSessionTemplate) { Assert.notNull(sqlSessionTemplate, "SqlSessionTemplate must not be null!"); this.sqlSession = sqlSessionTemplate; } public SqlSessionTemplate getSqlSession() { return sqlSession; } /** * Sub class can override this method. * @return Namespace */ protected abstract String getNamespace(); /** * get the mapper statement include namespace. * @param partStatement partStatement * @return Statement */ protected String getStatement(String partStatement) { return getNamespace() + DOT + partStatement; } /** * select one query. * @param statement statement * @param <T> entity class * @return result */ protected <T> T selectOne(String statement) { return getSqlSession().selectOne(getStatement(statement)); } protected <T> T selectOne(String statement, Object parameter) { return getSqlSession().selectOne(getStatement(statement), parameter); } protected <T> List<T> selectList(String statement) { return getSqlSession().selectList(getStatement(statement)); } protected <T> List<T> selectList(String statement, Object parameter) { return getSqlSession().selectList(getStatement(statement), parameter); } protected int insert(String statement, Object parameter) { return getSqlSession().insert(getStatement(statement), parameter); } protected int update(String statement, Object parameter) { return getSqlSession().update(getStatement(statement), parameter); } protected int delete(String statement) { return getSqlSession().delete(getStatement(statement)); } protected int delete(String statement, Object parameter) { return getSqlSession().delete(getStatement(statement), parameter); } /** * Calculate total mount. * @return if return -1 means can not judge ,need count from database. */ protected <X> long calculateTotal(Pageable pager, List<X> result) { if (pager.hasPrevious()) { if (CollectionUtils.isEmpty(result)) { return -1; } if (result.size() == pager.getPageSize()) { return -1; } return (pager.getPageNumber() - 1) * pager.getPageSize() + result.size(); } if (result.size() < pager.getPageSize()) { return result.size(); } return -1; } protected <X, Y> Page<X> findByPager(Pageable pager, String selectStatement, String countStatement, Y condition, Map<String, Object> otherParams) { Map<String, Object> params = new HashMap<>(); params.put("__offset", pager.getOffset()); params.put("__pageSize", pager.getPageSize()); params.put("__offsetEnd", pager.getOffset() + pager.getPageSize()); if (condition instanceof Sort && ((Sort) condition).isSorted()) { params.put("__sort", condition); } else if (null != pager && null != pager.getSort() && pager.getSort().isSorted()) { params.put("__sort", pager.getSort()); } params.put("__condition", condition); if (!CollectionUtils.isEmpty(otherParams)) { params.putAll(otherParams); } List<X> result = selectList(selectStatement, params); long total = calculateTotal(pager, result); if (total < 0) { total = selectOne(countStatement, params); } return new PageImpl<>(result, pager, total); } protected <X, Y> Page<X> findByPager(Pageable pager, String selectStatement, String countStatement, Y condition) { return findByPager(pager, selectStatement, countStatement, condition, null); } }