package com.revengemission.sso.oauth2.client.config; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.time.Clock; import java.time.Duration; import java.time.Instant; /** * 刷新过期access_token */ @Component public class RefreshExpiredTokenFilter extends OncePerRequestFilter { private static final Logger log = LoggerFactory.getLogger(RefreshExpiredTokenFilter.class); @Value("${oauth2.token.cookie.domain}") String cookieDomain; @Autowired OAuth2AuthorizedClientService oAuth2AuthorizedClientService; private Duration accessTokenExpiresSkew = Duration.ofMillis(10000); private Clock clock = Clock.systemUTC(); @Autowired OAuth2UserService<OAuth2UserRequest, OAuth2User> oAuth2UserService; private DefaultRefreshTokenTokenResponseClient accessTokenResponseClient; public RefreshExpiredTokenFilter() { super(); this.accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { log.debug("entering Refresh ExpiredToken Filter......"); /** * check if authentication is done. */ Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); if (null != authentication && authentication instanceof OAuth2AuthenticationToken) { OAuth2AuthenticationToken oldOAuth2Token = (OAuth2AuthenticationToken) authentication; OAuth2AuthorizedClient authorizedClient = this.oAuth2AuthorizedClientService .loadAuthorizedClient(oldOAuth2Token.getAuthorizedClientRegistrationId(), oldOAuth2Token.getName()); /** * Check whether token is expired. */ if (authorizedClient != null && isExpired(authorizedClient.getAccessToken())) { try { log.info("===================== Token Expired , trying to refresh"); ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); /* * Call Auth server token endpoint to refresh token. */ OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, authorizedClient.getAccessToken(), authorizedClient.getRefreshToken()); OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); OAuth2User newOAuth2User = oAuth2UserService.loadUser(new OAuth2UserRequest(clientRegistration, accessTokenResponse.getAccessToken())); /* * Create new authentication(OAuth2AuthenticationToken). */ OAuth2AuthenticationToken updatedUser = new OAuth2AuthenticationToken(newOAuth2User, newOAuth2User.getAuthorities(), oldOAuth2Token.getAuthorizedClientRegistrationId()); /* * Update access_token and refresh_token by saving new authorized client. */ OAuth2AuthorizedClient updatedAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration, oldOAuth2Token.getName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()); this.oAuth2AuthorizedClientService.saveAuthorizedClient(updatedAuthorizedClient, updatedUser); /* * Set new authentication in SecurityContextHolder. */ SecurityContextHolder.getContext().setAuthentication(updatedUser); Cookie tokenCookie = new Cookie("access_token", accessTokenResponse.getAccessToken().getTokenValue()); tokenCookie.setHttpOnly(true); tokenCookie.setDomain(cookieDomain); tokenCookie.setPath("/"); response.addCookie(tokenCookie); log.info("===================== Refresh Token Done !"); } catch (OAuth2AuthorizationException e) { log.info("Refresh ExpiredToken exception", e); SecurityContextHolder.getContext().setAuthentication(null); } } } log.debug("exit Refresh ExpiredToken Filter......"); filterChain.doFilter(request, response); } private Boolean isExpired(OAuth2AccessToken oAuth2AccessToken) { Instant now = this.clock.instant(); Instant expiresAt = oAuth2AccessToken.getExpiresAt(); return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew)); } }