package cc.mrbird.security.code.img; import cc.mrbird.common.domain.FebsConstant; import cc.mrbird.security.exception.ValidateCodeException; import cc.mrbird.security.properties.FebsSecurityProperties; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.InitializingBean; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.social.connect.web.HttpSessionSessionStrategy; import org.springframework.social.connect.web.SessionStrategy; import org.springframework.util.AntPathMatcher; import org.springframework.web.bind.ServletRequestBindingException; import org.springframework.web.bind.ServletRequestUtils; import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.filter.OncePerRequestFilter; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.Arrays; import java.util.HashSet; import java.util.Set; public class ImageCodeFilter extends OncePerRequestFilter implements InitializingBean { private AuthenticationFailureHandler authenticationFailureHandler; private SessionStrategy sessionStrategy = new HttpSessionSessionStrategy(); private Set<String> url = new HashSet<>(); private FebsSecurityProperties securityProperties; private AntPathMatcher pathMatcher = new AntPathMatcher(); @Override public void afterPropertiesSet() throws ServletException { super.afterPropertiesSet(); String[] configUrls = StringUtils.splitByWholeSeparatorPreserveAllTokens(securityProperties.getCode().getImage().getUrl(), ","); url.addAll(Arrays.asList(configUrls)); url.add(securityProperties.getCode().getImage().getLoginProcessingUrl()); } @Override protected void doFilterInternal(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, FilterChain filterChain) throws ServletException, IOException { boolean match = false; for (String u : url) { if (pathMatcher.match(u, httpServletRequest.getRequestURI())) { match = true; } } if (match) { try { validateCode(new ServletWebRequest(httpServletRequest)); } catch (ValidateCodeException e) { authenticationFailureHandler.onAuthenticationFailure(httpServletRequest, httpServletResponse, e); return; } } filterChain.doFilter(httpServletRequest, httpServletResponse); } private void validateCode(ServletWebRequest servletWebRequest) throws ServletRequestBindingException { ImageCode codeInSession = (ImageCode) sessionStrategy.getAttribute(servletWebRequest, FebsConstant.SESSION_KEY_IMAGE_CODE); String codeInRequest = ServletRequestUtils.getStringParameter(servletWebRequest.getRequest(), "imageCode"); if (StringUtils.isBlank(codeInRequest)) { throw new ValidateCodeException("验证码不能为空!"); } if (codeInSession == null) { throw new ValidateCodeException("验证码不存在,请重新发送!"); } if (codeInSession.isExpire()) { sessionStrategy.removeAttribute(servletWebRequest, FebsConstant.SESSION_KEY_IMAGE_CODE); throw new ValidateCodeException("验证码已过期,请重新发送!"); } if (!StringUtils.equalsIgnoreCase(codeInSession.getCode(), codeInRequest)) { throw new ValidateCodeException("验证码不正确!"); } sessionStrategy.removeAttribute(servletWebRequest, FebsConstant.SESSION_KEY_IMAGE_CODE); } public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { this.authenticationFailureHandler = authenticationFailureHandler; } public void setSecurityProperties(FebsSecurityProperties securityProperties) { this.securityProperties = securityProperties; } }