package io.pivotal.rsocketserver; import io.pivotal.rsocketserver.data.Message; import io.rsocket.SocketAcceptor; import io.rsocket.metadata.WellKnownMimeType; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.rsocket.context.LocalRSocketServerPort; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.rsocket.RSocketRequester; import org.springframework.messaging.rsocket.RSocketStrategies; import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; import org.springframework.security.rsocket.metadata.SimpleAuthenticationEncoder; import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.time.Duration; import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest public class RSocketClientToServerITest { private static RSocketRequester requester; private static UsernamePasswordMetadata credentials; private static MimeType mimeType; @BeforeAll public static void setupOnce(@Autowired RSocketRequester.Builder builder, @LocalRSocketServerPort Integer port, @Autowired RSocketStrategies strategies) { SocketAcceptor responder = RSocketMessageHandler.responder(strategies, new ClientHandler()); credentials = new UsernamePasswordMetadata("user", "pass"); mimeType = MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()); requester = builder .setupRoute("shell-client") .setupData(UUID.randomUUID().toString()) .setupMetadata(credentials, mimeType) .rsocketStrategies(b -> b.encoder(new SimpleAuthenticationEncoder())) .rsocketConnector(connector -> connector.acceptor(responder)) .connectTcp("localhost", port) .block(); } @Test public void testFireAndForget() { // Send a fire-and-forget message Mono<Void> result = requester .route("fire-and-forget") .data(new Message("TEST", "Fire-And-Forget")) .retrieveMono(Void.class); // Assert that the result is a completed Mono. StepVerifier .create(result) .verifyComplete(); } @Test public void testRequestGetsResponse() { // Send a request message Mono<Message> result = requester .route("request-response") .data(new Message("TEST", "Request")) .retrieveMono(Message.class); // Verify that the response message contains the expected data StepVerifier .create(result) .consumeNextWith(message -> { assertThat(message.getOrigin()).isEqualTo(RSocketController.SERVER); assertThat(message.getInteraction()).isEqualTo(RSocketController.RESPONSE); assertThat(message.getIndex()).isEqualTo(0); }) .verifyComplete(); } @Test public void testRequestGetsStream() { // Send a request message Flux<Message> result = requester .route("stream") .data(new Message("TEST", "Stream")) .retrieveFlux(Message.class); // Verify that the response messages contain the expected data StepVerifier .create(result) .consumeNextWith(message -> { assertThat(message.getOrigin()).isEqualTo(RSocketController.SERVER); assertThat(message.getInteraction()).isEqualTo(RSocketController.STREAM); assertThat(message.getIndex()).isEqualTo(0L); }) .expectNextCount(3) .consumeNextWith(message -> { assertThat(message.getOrigin()).isEqualTo(RSocketController.SERVER); assertThat(message.getInteraction()).isEqualTo(RSocketController.STREAM); assertThat(message.getIndex()).isEqualTo(4L); }) .thenCancel() .verify(); } @Test public void testStreamGetsStream() { Mono<Duration> setting1 = Mono.just(Duration.ofSeconds(6)).delayElement(Duration.ofSeconds(0)); Mono<Duration> setting2 = Mono.just(Duration.ofSeconds(6)).delayElement(Duration.ofSeconds(9)); Flux<Duration> settings = Flux.concat(setting1, setting2); // Send a stream of request messages Flux<Message> result = requester .route("channel") .data(settings) .retrieveFlux(Message.class); // Verify that the response messages contain the expected data StepVerifier .create(result) .consumeNextWith(message -> { assertThat(message.getOrigin()).isEqualTo(RSocketController.SERVER); assertThat(message.getInteraction()).isEqualTo(RSocketController.CHANNEL); assertThat(message.getIndex()).isEqualTo(0L); }) .consumeNextWith(message -> { assertThat(message.getOrigin()).isEqualTo(RSocketController.SERVER); assertThat(message.getInteraction()).isEqualTo(RSocketController.CHANNEL); assertThat(message.getIndex()).isEqualTo(0L); }) .thenCancel() .verify(); } @Test public void testNoMatchingRouteGetsException() { // Send a request with bad route and data Mono<String> result = requester .route("invalid") .data("anything") .retrieveMono(String.class); // Verify that an error is generated StepVerifier.create(result) .expectErrorMessage("No handler for destination 'invalid'") .verify(Duration.ofSeconds(5)); } @AfterAll public static void tearDownOnce() { requester.rsocket().dispose(); } @Slf4j static class ClientHandler { @MessageMapping("client-status") public Flux<String> statusUpdate(String status) { log.info("Connection {}", status); return Flux.interval(Duration.ofSeconds(5)).map(index -> String.valueOf(Runtime.getRuntime().freeMemory())); } } }