/* * Copyright 2018-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.vault.authentication; import java.time.Duration; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.Trigger; import org.springframework.vault.authentication.event.AfterLoginEvent; import org.springframework.vault.authentication.event.AfterLoginTokenRenewedEvent; import org.springframework.vault.authentication.event.AfterLoginTokenRevocationEvent; import org.springframework.vault.authentication.event.AuthenticationErrorEvent; import org.springframework.vault.authentication.event.AuthenticationErrorListener; import org.springframework.vault.authentication.event.AuthenticationEvent; import org.springframework.vault.authentication.event.AuthenticationListener; import org.springframework.vault.authentication.event.BeforeLoginTokenRenewedEvent; import org.springframework.vault.authentication.event.BeforeLoginTokenRevocationEvent; import org.springframework.vault.authentication.event.LoginFailedEvent; import org.springframework.vault.authentication.event.LoginTokenExpiredEvent; import org.springframework.vault.support.LeaseStrategy; import org.springframework.vault.support.VaultResponse; import org.springframework.vault.support.VaultToken; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient.RequestBodySpec; import org.springframework.web.reactive.function.client.WebClient.RequestBodyUriSpec; import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec; import org.springframework.web.reactive.function.client.WebClient.RequestHeadersUriSpec; import org.springframework.web.reactive.function.client.WebClient.ResponseSpec; import org.springframework.web.reactive.function.client.WebClientResponseException; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; /** * Unit tests for {@link ReactiveLifecycleAwareSessionManager}. * * @author Mark Paluch */ @ExtendWith(MockitoExtension.class) @MockitoSettings(strictness = Strictness.LENIENT) class ReactiveLifecycleAwareSessionManagerUnitTests { @Mock VaultTokenSupplier tokenSupplier; @Mock TaskScheduler taskScheduler; @Mock WebClient webClient; @Mock RequestHeadersSpec requestHeadersSpec; @Mock RequestHeadersUriSpec requestHeadersUriSpec; @Mock RequestBodyUriSpec requestBodyUriSpec; @Mock RequestBodySpec requestBodySpec; @Mock ResponseSpec responseSpec; @Mock AuthenticationListener listener; @Mock AuthenticationErrorListener errorListener; @Captor ArgumentCaptor<AuthenticationEvent> captor; private ReactiveLifecycleAwareSessionManager sessionManager; @BeforeEach void before() { // POST when(this.webClient.post()).thenReturn(this.requestBodyUriSpec); when(this.requestBodyUriSpec.uri(anyString())).thenReturn(this.requestBodySpec); when(this.requestBodySpec.headers(any())).thenReturn(this.requestBodySpec); when(this.requestBodySpec.retrieve()).thenReturn(this.responseSpec); // GET when(this.webClient.get()).thenReturn(this.requestHeadersUriSpec); when(this.requestHeadersUriSpec.uri(anyString())).thenReturn(this.requestHeadersSpec); when(this.requestHeadersSpec.headers(any())).thenReturn(this.requestHeadersSpec); when(this.requestHeadersSpec.retrieve()).thenReturn(this.responseSpec); this.sessionManager = new ReactiveLifecycleAwareSessionManager(this.tokenSupplier, this.taskScheduler, this.webClient); this.sessionManager.addAuthenticationListener(this.listener); this.sessionManager.addErrorListener(this.errorListener); } @Test void shouldObtainTokenFromClientAuthentication() { mockToken(LoginToken.of("login")); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .expectNext(LoginToken.of("login")) // .verifyComplete(); verify(this.listener).onAuthenticationEvent(any(AfterLoginEvent.class)); } @Test void loginShouldFail() { when(this.tokenSupplier.getVaultToken()).thenReturn(Mono.error(new VaultLoginException("foo"))); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .verifyError(); verifyZeroInteractions(this.listener); verify(this.errorListener).onAuthenticationError(any(LoginFailedEvent.class)); } @Test @SuppressWarnings("unchecked") void shouldSelfLookupToken() { VaultResponse vaultResponse = new VaultResponse(); vaultResponse.setData(Collections.singletonMap("ttl", 100)); mockToken(VaultToken.of("login")); when(this.responseSpec.bodyToMono((Class) any())).thenReturn(Mono.just(vaultResponse)); this.sessionManager.getSessionToken().as(StepVerifier::create).assertNext(it -> { LoginToken sessionToken = (LoginToken) it; assertThat(sessionToken.getLeaseDuration()).isEqualTo(Duration.ofSeconds(100)); }).verifyComplete(); verify(this.webClient.get()).uri("auth/token/lookup-self"); verify(this.listener).onAuthenticationEvent(this.captor.capture()); AfterLoginEvent event = (AfterLoginEvent) this.captor.getValue(); assertThat(event.getSource()).isInstanceOf(LoginToken.class); } @Test @SuppressWarnings("unchecked") void shouldContinueIfSelfLookupFails() { VaultResponse vaultResponse = new VaultResponse(); vaultResponse.setData(Collections.singletonMap("ttl", 100)); mockToken(VaultToken.of("login")); when(this.responseSpec.bodyToMono((Class) any())).thenReturn( Mono.error(new WebClientResponseException("forbidden", 403, "Forbidden", null, null, null))); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .assertNext(it -> { assertThat(it).isExactlyInstanceOf(VaultToken.class); }).verifyComplete(); verify(this.listener).onAuthenticationEvent(any(AfterLoginEvent.class)); verify(this.errorListener).onAuthenticationError(any()); } @Test void tokenRenewalShouldMapException() { mockToken(LoginToken.renewable("foo".toCharArray(), Duration.ofMinutes(1))); when(this.responseSpec.bodyToMono((Class) any())).thenReturn(Mono.error( new WebClientResponseException("Some server error", 500, "Some server error", null, null, null))); AtomicReference<AuthenticationErrorEvent> listener = new AtomicReference<>(); this.sessionManager.addErrorListener(listener::set); this.sessionManager.getVaultToken().as(StepVerifier::create).expectNextCount(1).verifyComplete(); this.sessionManager.renewToken().as(StepVerifier::create).verifyComplete(); assertThat(listener.get().getException()).isInstanceOf(VaultTokenRenewalException.class) .hasCauseInstanceOf(WebClientResponseException.class) .hasMessageContaining("Cannot renew token: Status 500 Some server error"); } @Test void shouldRevokeLoginTokenOnDestroy() { VaultResponse vaultResponse = new VaultResponse(); vaultResponse.setData(Collections.singletonMap("ttl", 100)); mockToken(LoginToken.of("login")); when(this.responseSpec.bodyToMono(String.class)).thenReturn(Mono.just("OK")); this.sessionManager.getVaultToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); this.sessionManager.destroy(); verify(this.webClient.post()).uri("auth/token/revoke-self"); verify(this.listener).onAuthenticationEvent(any(BeforeLoginTokenRevocationEvent.class)); verify(this.listener).onAuthenticationEvent(any(AfterLoginTokenRevocationEvent.class)); } @Test void shouldNotRevokeRegularTokenOnDestroy() { mockToken(VaultToken.of("login")); this.sessionManager.setTokenSelfLookupEnabled(false); this.sessionManager.renewToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); this.sessionManager.destroy(); verify(this.webClient, never()).post(); verify(this.webClient.post(), never()).uri("auth/token/revoke-self"); verify(this.listener).onAuthenticationEvent(any(AfterLoginEvent.class)); verifyNoMoreInteractions(this.listener); } @Test void shouldNotThrowExceptionsOnRevokeErrors() { mockToken(LoginToken.of("login")); when(this.responseSpec.bodyToMono((Class) any())).thenReturn( Mono.error(new WebClientResponseException("forbidden", 403, "Forbidden", null, null, null))); this.sessionManager.renewToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); this.sessionManager.destroy(); verify(this.requestBodyUriSpec).uri("auth/token/revoke-self"); } @Test void shouldScheduleTokenRenewal() { mockToken(LoginToken.renewable("login".toCharArray(), Duration.ofSeconds(5))); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); verify(this.taskScheduler).schedule(any(Runnable.class), any(Trigger.class)); } @Test void shouldRunTokenRenewal() { mockToken(LoginToken.renewable("login".toCharArray(), Duration.ofSeconds(5))); ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class); VaultResponse vaultResponse = new VaultResponse(); Map<String, Object> auth = new HashMap<>(); auth.put("client_token", "login"); auth.put("ttl", 100); vaultResponse.setAuth(auth); when(this.responseSpec.bodyToMono(VaultResponse.class)).thenReturn(Mono.just(vaultResponse)); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); verify(this.taskScheduler).schedule(runnableCaptor.capture(), any(Trigger.class)); runnableCaptor.getValue().run(); verify(this.webClient).post(); verify(this.webClient.post()).uri("auth/token/renew-self"); verify(this.tokenSupplier, times(1)).getVaultToken(); verify(this.listener).onAuthenticationEvent(any(BeforeLoginTokenRenewedEvent.class)); verify(this.listener).onAuthenticationEvent(any(AfterLoginTokenRenewedEvent.class)); } @Test void shouldReScheduleTokenRenewalAfterSuccessfulRenewal() { mockToken(LoginToken.renewable("login".toCharArray(), Duration.ofSeconds(5))); when(this.responseSpec.bodyToMono(VaultResponse.class)) .thenReturn(Mono.just(fromToken(LoginToken.of("foo".toCharArray(), Duration.ofSeconds(10))))); ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); verify(this.taskScheduler).schedule(runnableCaptor.capture(), any(Trigger.class)); runnableCaptor.getValue().run(); verify(this.taskScheduler, times(2)).schedule(any(Runnable.class), any(Trigger.class)); } @Test void shouldNotScheduleRenewalIfRenewalTtlExceedsThreshold() { mockToken(LoginToken.renewable("login".toCharArray(), Duration.ofSeconds(5))); when(this.responseSpec.bodyToMono(VaultResponse.class)) .thenReturn(Mono.just(fromToken(LoginToken.of("foo".toCharArray(), Duration.ofSeconds(2))))); ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class); this.sessionManager.getSessionToken().as(StepVerifier::create).expectNextCount(1).verifyComplete(); verify(this.taskScheduler).schedule(runnableCaptor.capture(), any(Trigger.class)); runnableCaptor.getValue().run(); verify(this.taskScheduler, times(1)).schedule(any(Runnable.class), any(Trigger.class)); } @Test void shouldReLoginIfRenewalTtlExceedsThreshold() { when(this.tokenSupplier.getVaultToken()).thenReturn( Mono.just(LoginToken.renewable("login".toCharArray(), Duration.ofSeconds(5))), Mono.just(LoginToken.renewable("bar".toCharArray(), Duration.ofSeconds(5)))); when(this.responseSpec.bodyToMono(VaultResponse.class)) .thenReturn(Mono.just(fromToken(LoginToken.of("foo".toCharArray(), Duration.ofSeconds(2))))); ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); verify(this.taskScheduler).schedule(runnableCaptor.capture(), any(Trigger.class)); runnableCaptor.getValue().run(); this.sessionManager.getSessionToken().as(StepVerifier::create) .expectNext(LoginToken.renewable("bar".toCharArray(), Duration.ofSeconds(5))).verifyComplete(); verify(this.tokenSupplier, times(2)).getVaultToken(); verify(this.listener, times(2)).onAuthenticationEvent(any(AfterLoginEvent.class)); verify(this.listener).onAuthenticationEvent(any(LoginTokenExpiredEvent.class)); } @Test void shouldReLoginIfRenewFails() { when(this.tokenSupplier.getVaultToken()).thenReturn( Mono.just(LoginToken.renewable("login".toCharArray(), Duration.ofSeconds(5))), Mono.just(LoginToken.renewable("bar".toCharArray(), Duration.ofSeconds(5)))); when(this.responseSpec.bodyToMono(VaultResponse.class)).thenReturn(Mono.error(new RuntimeException("foo"))); ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); verify(this.taskScheduler).schedule(runnableCaptor.capture(), any(Trigger.class)); runnableCaptor.getValue().run(); this.sessionManager.getSessionToken().as(StepVerifier::create) .expectNext(LoginToken.renewable("bar".toCharArray(), Duration.ofSeconds(5))).verifyComplete(); verify(this.tokenSupplier, times(2)).getVaultToken(); } @Test void shouldRetainTokenAfterRenewalFailure() { when(this.tokenSupplier.getVaultToken()).thenReturn( Mono.just(LoginToken.renewable("login".toCharArray(), Duration.ofSeconds(5))), Mono.just(LoginToken.renewable("bar".toCharArray(), Duration.ofSeconds(5)))); when(this.responseSpec.bodyToMono(VaultResponse.class)).thenReturn(Mono.error(new RuntimeException("foo"))); this.sessionManager.setLeaseStrategy(LeaseStrategy.retainOnError()); ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class); this.sessionManager.getSessionToken() // .as(StepVerifier::create) // .expectNextCount(1) // .verifyComplete(); verify(this.taskScheduler).schedule(runnableCaptor.capture(), any(Trigger.class)); runnableCaptor.getValue().run(); this.sessionManager.getSessionToken().as(StepVerifier::create) .expectNext(LoginToken.renewable("login".toCharArray(), Duration.ofSeconds(5))).verifyComplete(); verify(this.tokenSupplier).getVaultToken(); } private static VaultResponse fromToken(LoginToken loginToken) { Map<String, Object> auth = new HashMap<>(); auth.put("client_token", loginToken.getToken()); auth.put("renewable", loginToken.isRenewable()); auth.put("lease_duration", loginToken.getLeaseDuration().getSeconds()); VaultResponse response = new VaultResponse(); response.setAuth(auth); return response; } private void mockToken(VaultToken token) { when(this.tokenSupplier.getVaultToken()).thenReturn(Mono.just(token)); } }