package com.kristijangeorgiev.resource.util;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import org.springframework.boot.autoconfigure.security.oauth2.resource.JwtAccessTokenConverterConfigurer;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.token.AccessTokenConverter;
import org.springframework.security.oauth2.provider.token.UserAuthenticationConverter;
import org.springframework.security.oauth2.provider.token.store.JwtAccessTokenConverter;
import org.springframework.stereotype.Component;

/**
 * Converts between OAuth2 access tokens / authentications and their claim-map
 * representation, delegating the user part to a
 * {@link CustomUserAuthenticationConverter} by default.
 * <p>
 * The component also implements {@link JwtAccessTokenConverterConfigurer} and
 * installs itself as the access token converter of the
 * {@link JwtAccessTokenConverter} it is asked to configure.
 *
 * @author Kristijan Georgiev
 *
 */
@Component
public class CustomAccessTokenConverter implements AccessTokenConverter, JwtAccessTokenConverterConfigurer {

    /** When {@code true}, the grant type is copied to and from the claim map. Defaults to {@code false}. */
    private boolean includeGrantType;

    /** Delegate used for the user-authentication part of the claims. */
    private UserAuthenticationConverter userTokenConverter = new CustomUserAuthenticationConverter();

    /**
     * Registers this instance as the access token converter of the given JWT converter.
     *
     * @param converter the JWT converter to configure
     */
    @Override
    public void configure(JwtAccessTokenConverter converter) {
        converter.setAccessTokenConverter(this);
    }

    /**
     * Builds an access token from its raw value and its claim map.
     * <p>
     * The {@code exp}, {@code aud}, {@code client_id} and {@code scope} claims are
     * consumed here; every other claim (including {@code jti}) is kept as
     * additional information on the returned token.
     *
     * @param value the raw token value
     * @param map   the claims of the token
     * @return the token populated with expiration, scope and additional information
     * @throws ClassCastException   if the {@code exp} claim is present but not numeric
     * @throws NullPointerException if the {@code exp} claim is present but {@code null}
     */
    @Override
    public OAuth2AccessToken extractAccessToken(String value, Map<String, ?> map) {
        DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(value);

        // Work on a copy so the caller's map is never modified.
        Map<String, Object> info = new HashMap<String, Object>(map);
        info.remove(EXP);
        info.remove(AUD);
        info.remove(CLIENT_ID);
        info.remove(SCOPE);

        if (map.containsKey(EXP)) {
            // The claim is in seconds. Cast to Number rather than Long because the
            // numeric type (Integer vs. Long) depends on how the map was produced.
            long expirySeconds = ((Number) map.get(EXP)).longValue();
            token.setExpiration(new Date(expirySeconds * 1000L));
        }

        token.setScope(extractScope(map));
        token.setAdditionalInformation(info);
        return token;
    }

    /**
     * Rebuilds the OAuth2 authentication described by the given claim map.
     * <p>
     * Client authorities are only read from the {@code authorities} claim when the
     * user converter yields no user authentication (i.e. a client-only token).
     *
     * @param map the claims of the token
     * @return the authentication made of the reconstructed request and the user
     *         authentication (which may be {@code null})
     */
    @Override
    public OAuth2Authentication extractAuthentication(Map<String, ?> map) {
        Set<String> scope = extractScope(map);
        Authentication user = userTokenConverter.extractAuthentication(map);
        String clientId = (String) map.get(CLIENT_ID);

        Map<String, String> parameters = new HashMap<String, String>();
        parameters.put(CLIENT_ID, clientId);
        if (includeGrantType && map.containsKey(GRANT_TYPE)) {
            parameters.put(GRANT_TYPE, (String) map.get(GRANT_TYPE));
        }

        Collection<String> audience = map.containsKey(AUD) ? getAudience(map) : Collections.<String>emptySet();
        Set<String> resourceIds = new LinkedHashSet<String>(audience);

        Collection<? extends GrantedAuthority> authorities = null;
        if (user == null && map.containsKey(AUTHORITIES)) {
            @SuppressWarnings("unchecked")
            Collection<String> roleNames = (Collection<String>) map.get(AUTHORITIES);
            authorities = AuthorityUtils.createAuthorityList(roleNames.toArray(new String[0]));
        }

        OAuth2Request request = new OAuth2Request(parameters, clientId, authorities, true, scope, resourceIds, null,
                null, null);
        return new OAuth2Authentication(request, user);
    }

    /**
     * Converts a token and its authentication into a claim map.
     * <p>
     * Entries are written in a fixed order: the token's additional information may
     * overwrite the user/authorities, scope, {@code jti}, {@code exp} and grant
     * type entries, while {@code client_id} and {@code aud} are written last and
     * therefore always come from the OAuth2 request.
     *
     * @param token          the access token to convert
     * @param authentication the authentication the token was issued for
     * @return the claims representing the token
     */
    @Override
    public Map<String, ?> convertAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
        Map<String, Object> response = new HashMap<String, Object>();
        OAuth2Request clientToken = authentication.getOAuth2Request();

        if (!authentication.isClientOnly()) {
            response.putAll(userTokenConverter.convertUserAuthentication(authentication.getUserAuthentication()));
        } else if (clientToken.getAuthorities() != null && !clientToken.getAuthorities().isEmpty()) {
            response.put(UserAuthenticationConverter.AUTHORITIES,
                    AuthorityUtils.authorityListToSet(clientToken.getAuthorities()));
        }

        if (token.getScope() != null) {
            response.put(SCOPE, token.getScope());
        }
        if (token.getAdditionalInformation().containsKey(JTI)) {
            response.put(JTI, token.getAdditionalInformation().get(JTI));
        }
        if (token.getExpiration() != null) {
            // Stored as seconds, mirroring the conversion in extractAccessToken.
            response.put(EXP, token.getExpiration().getTime() / 1000);
        }
        if (includeGrantType && clientToken.getGrantType() != null) {
            response.put(GRANT_TYPE, clientToken.getGrantType());
        }

        response.putAll(token.getAdditionalInformation());

        // Written after the additional information so these two cannot be overridden by it.
        response.put(CLIENT_ID, clientToken.getClientId());
        if (clientToken.getResourceIds() != null && !clientToken.getResourceIds().isEmpty()) {
            response.put(AUD, clientToken.getResourceIds());
        }
        return response;
    }

    /**
     * Reads the {@code aud} claim, which may be a collection or a single value.
     *
     * @param map the claims of the token
     * @return the audience collection as found in the map, or a singleton holding
     *         the single audience value
     */
    private Collection<String> getAudience(Map<String, ?> map) {
        Object auds = map.get(AUD);
        if (auds instanceof Collection) {
            @SuppressWarnings("unchecked")
            Collection<String> result = (Collection<String>) auds;
            return result;
        }
        return Collections.singleton((String) auds);
    }

    /**
     * Reads the {@code scope} claim, which may be a space-separated string or a
     * collection of strings.
     *
     * @param map the claims of the token
     * @return the scopes in encounter order; an empty set when the claim is
     *         missing, {@code null}, or of any other type
     */
    private Set<String> extractScope(Map<String, ?> map) {
        // instanceof is false for null, so a present-but-null claim falls through
        // to the empty set instead of failing.
        Object scopeObj = map.get(SCOPE);
        if (scopeObj instanceof String) {
            return new LinkedHashSet<String>(Arrays.asList(((String) scopeObj).split(" ")));
        }
        if (scopeObj instanceof Collection) {
            @SuppressWarnings("unchecked")
            Collection<String> scopeColl = (Collection<String>) scopeObj;
            return new LinkedHashSet<String>(scopeColl);
        }
        return Collections.emptySet();
    }

    /**
     * Replaces the converter used for the user-authentication part of the claims.
     *
     * @param userTokenConverter the converter to delegate to
     */
    public void setUserTokenConverter(UserAuthenticationConverter userTokenConverter) {
        this.userTokenConverter = userTokenConverter;
    }

    /**
     * Controls whether the grant type is copied to and from the claim map.
     *
     * @param includeGrantType {@code true} to include the grant type
     */
    public void setIncludeGrantType(boolean includeGrantType) {
        this.includeGrantType = includeGrantType;
    }
}