/*
 * Copyright 2002-2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.messaging.rsocket;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import io.reactivex.Completable;
import io.reactivex.Observable;
import io.reactivex.Single;
import io.rsocket.AbstractRSocket;
import io.rsocket.Payload;
import org.junit.Before;
import org.junit.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.lang.Nullable;
import org.springframework.messaging.rsocket.RSocketRequester.RequestSpec;
import org.springframework.messaging.rsocket.RSocketRequester.ResponseSpec;
import org.springframework.util.MimeTypeUtils;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Unit tests for {@link DefaultRSocketRequester}.
 *
 * <p>Requests are sent to a {@link TestRSocket} stub that records which RSocket
 * interaction method was invoked and with which payload(s), and that replies
 * with canned responses configured per test.
 *
 * @author Rossen Stoyanchev
 */
public class DefaultRSocketRequesterTests {

	private static final Duration MILLIS_10 = Duration.ofMillis(10);


	private TestRSocket rsocket;

	private RSocketRequester requester;

	private final DefaultDataBufferFactory bufferFactory = new DefaultDataBufferFactory();


	@Before
	public void setUp() {
		RSocketStrategies strategies = RSocketStrategies.builder()
				.decoder(StringDecoder.allMimeTypes())
				.encoder(CharSequenceEncoder.allMimeTypes())
				.build();
		this.rsocket = new TestRSocket();
		this.requester = RSocketRequester.wrap(this.rsocket, MimeTypeUtils.TEXT_PLAIN, strategies);
	}


	@Test
	public void singlePayload() {

		// data(Object)
		testSinglePayload(spec -> spec.data("bodyA"), "bodyA");
		testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).map(l -> "bodyA")), "bodyA");
		testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).then()), "");
		testSinglePayload(spec -> spec.data(Single.timer(10, MILLISECONDS).map(l -> "bodyA")), "bodyA");
		testSinglePayload(spec -> spec.data(Completable.complete()), "");

		// data(Publisher<T>, Class<T>)
		testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).map(l -> "bodyA"), String.class), "bodyA");
		testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).map(l -> "bodyA"), Object.class), "bodyA");
		testSinglePayload(spec -> spec.data(Mono.delay(MILLIS_10).then(), Void.class), "");
	}

	/**
	 * Send a request to route "toA" via {@code send()} and verify that it was
	 * dispatched as a single fire-and-forget payload.
	 * @param mapper applies the data under test to the request spec
	 * @param expectedValue the expected UTF-8 data of the recorded payload
	 */
	private void testSinglePayload(Function<RequestSpec, ResponseSpec> mapper, String expectedValue) {
		// Clear what the previous case recorded, so that a request that never
		// reaches the RSocket cannot pass on stale state (as in testMultiPayload).
		this.rsocket.reset();
		mapper.apply(this.requester.route("toA")).send().block(Duration.ofSeconds(5));

		assertEquals("fireAndForget", this.rsocket.getSavedMethodName());
		Payload payload = this.rsocket.getSavedPayload();
		assertNotNull(payload);
		assertEquals("toA", payload.getMetadataUtf8());
		assertEquals(expectedValue, payload.getDataUtf8());
	}

	@Test
	public void multiPayload() {
		String[] values = new String[] {"bodyA", "bodyB", "bodyC"};
		Flux<String> stringFlux = Flux.fromArray(values).delayElements(MILLIS_10);

		// data(Object)
		testMultiPayload(spec -> spec.data(stringFlux), values);
		testMultiPayload(spec -> spec.data(Flux.empty()), "");
		testMultiPayload(spec -> spec.data(Observable.fromArray(values).delay(10, MILLISECONDS)), values);
		testMultiPayload(spec -> spec.data(Observable.empty()), "");

		// data(Publisher<T>, Class<T>)
		testMultiPayload(spec -> spec.data(stringFlux, String.class), values);
		testMultiPayload(spec -> spec.data(stringFlux.cast(Object.class), Object.class), values);
	}

	/**
	 * Send a request to route "toA" via {@code retrieveFlux} and verify that it
	 * was dispatched as a request-channel interaction with the expected payloads.
	 * @param mapper applies the data under test to the request spec
	 * @param expectedValues the expected UTF-8 data of each payload, in order;
	 * a single {@code ""} means one payload carrying only the route
	 */
	private void testMultiPayload(Function<RequestSpec, ResponseSpec> mapper, String... expectedValues) {
		this.rsocket.reset();
		mapper.apply(this.requester.route("toA")).retrieveFlux(String.class).blockLast(Duration.ofSeconds(5));

		assertEquals("requestChannel", this.rsocket.getSavedMethodName());
		Flux<Payload> payloadFlux = this.rsocket.getSavedPayloadFlux();
		assertNotNull(payloadFlux);
		List<Payload> payloads = payloadFlux.collectList().block(Duration.ofSeconds(5));
		assertNotNull(payloads);

		if (Arrays.equals(new String[] {""}, expectedValues)) {
			assertEquals(1, payloads.size());
			assertEquals("toA", payloads.get(0).getMetadataUtf8());
			assertEquals("", payloads.get(0).getDataUtf8());
		}
		else {
			// Only the first payload is expected to carry the route as metadata;
			// the remaining ones have empty metadata. Sized to match the input
			// rather than hard-coding exactly three payloads.
			String[] expectedMetadata = new String[expectedValues.length];
			for (int i = 0; i < expectedMetadata.length; i++) {
				expectedMetadata[i] = (i == 0 ? "toA" : "");
			}
			assertArrayEquals(expectedMetadata,
					payloads.stream().map(Payload::getMetadataUtf8).toArray(String[]::new));
			assertArrayEquals(expectedValues,
					payloads.stream().map(Payload::getDataUtf8).toArray(String[]::new));
		}
	}

	@Test
	public void send() {
		String value = "bodyA";
		this.requester.route("toA").data(value).send().block(Duration.ofSeconds(5));

		assertEquals("fireAndForget", this.rsocket.getSavedMethodName());
		Payload payload = this.rsocket.getSavedPayload();
		assertNotNull(payload);
		assertEquals("toA", payload.getMetadataUtf8());
		assertEquals("bodyA", payload.getDataUtf8());
	}

	@Test
	public void retrieveMono() {
		String value = "bodyA";
		this.rsocket.setPayloadMonoToReturn(Mono.delay(MILLIS_10).thenReturn(toPayload(value)));
		Mono<String> response = this.requester.route("").data("").retrieveMono(String.class);

		StepVerifier.create(response).expectNext(value).expectComplete().verify(Duration.ofSeconds(5));
		assertEquals("requestResponse", this.rsocket.getSavedMethodName());
	}

	@Test
	public void retrieveMonoVoid() {
		AtomicBoolean consumed = new AtomicBoolean(false);
		Mono<Payload> mono = Mono.delay(MILLIS_10).thenReturn(toPayload("bodyA")).doOnSuccess(p -> consumed.set(true));
		this.rsocket.setPayloadMonoToReturn(mono);
		this.requester.route("").data("").retrieveMono(Void.class).block(Duration.ofSeconds(5));

		assertTrue(consumed.get());
		assertEquals("requestResponse", this.rsocket.getSavedMethodName());
	}

	@Test
	public void retrieveFlux() {
		String[] values = new String[] {"bodyA", "bodyB", "bodyC"};
		this.rsocket.setPayloadFluxToReturn(Flux.fromArray(values).delayElements(MILLIS_10).map(this::toPayload));
		Flux<String> response = this.requester.route("").data("").retrieveFlux(String.class);

		StepVerifier.create(response).expectNext(values).expectComplete().verify(Duration.ofSeconds(5));
		assertEquals("requestStream", this.rsocket.getSavedMethodName());
	}

	@Test
	public void retrieveFluxVoid() {
		AtomicBoolean consumed = new AtomicBoolean(false);
		Flux<Payload> flux = Flux.just("bodyA", "bodyB")
				.delayElements(MILLIS_10).map(this::toPayload).doOnComplete(() -> consumed.set(true));
		this.rsocket.setPayloadFluxToReturn(flux);
		this.requester.route("").data("").retrieveFlux(Void.class).blockLast(Duration.ofSeconds(5));

		assertTrue(consumed.get());
		assertEquals("requestStream", this.rsocket.getSavedMethodName());
	}

	@Test
	public void rejectFluxToMono() {
		try {
			this.requester.route("").data(Flux.just("a", "b")).retrieveMono(String.class);
			fail();
		}
		catch (IllegalArgumentException ex) {
			assertEquals("No RSocket interaction model for Flux request to Mono response.", ex.getMessage());
		}
	}

	/**
	 * Wrap the UTF-8 bytes of the given value as a data-only {@link Payload}.
	 */
	private Payload toPayload(String value) {
		return PayloadUtils.createPayload(bufferFactory.wrap(value.getBytes(StandardCharsets.UTF_8)));
	}


	/**
	 * RSocket stub that records the name of the last interaction method invoked
	 * along with its payload (or payload stream for request-channel), and
	 * replies with the configured canned responses.
	 */
	private static class TestRSocket extends AbstractRSocket {

		private Mono<Payload> payloadMonoToReturn = Mono.empty();
		private Flux<Payload> payloadFluxToReturn = Flux.empty();

		@Nullable private volatile String savedMethodName;
		@Nullable private volatile Payload savedPayload;
		@Nullable private volatile Flux<Payload> savedPayloadFlux;

		/** Set the response returned from {@link #requestResponse}. */
		void setPayloadMonoToReturn(Mono<Payload> payloadMonoToReturn) {
			this.payloadMonoToReturn = payloadMonoToReturn;
		}

		/** Set the response returned from {@link #requestStream} and {@link #requestChannel}. */
		void setPayloadFluxToReturn(Flux<Payload> payloadFluxToReturn) {
			this.payloadFluxToReturn = payloadFluxToReturn;
		}

		@Nullable
		String getSavedMethodName() {
			return this.savedMethodName;
		}

		@Nullable
		Payload getSavedPayload() {
			return this.savedPayload;
		}

		@Nullable
		Flux<Payload> getSavedPayloadFlux() {
			return this.savedPayloadFlux;
		}

		/** Clear the recorded interaction; the configured responses are kept. */
		void reset() {
			this.savedMethodName = null;
			this.savedPayload = null;
			this.savedPayloadFlux = null;
		}


		@Override
		public Mono<Void> fireAndForget(Payload payload) {
			this.savedMethodName = "fireAndForget";
			this.savedPayload = payload;
			return Mono.empty();
		}

		@Override
		public Mono<Payload> requestResponse(Payload payload) {
			this.savedMethodName = "requestResponse";
			this.savedPayload = payload;
			return this.payloadMonoToReturn;
		}

		@Override
		public Flux<Payload> requestStream(Payload payload) {
			this.savedMethodName = "requestStream";
			this.savedPayload = payload;
			return this.payloadFluxToReturn;
		}

		@Override
		public Flux<Payload> requestChannel(Publisher<Payload> publisher) {
			this.savedMethodName = "requestChannel";
			// The request publisher is not subscribed here; tests collect it afterwards.
			this.savedPayloadFlux = Flux.from(publisher);
			return this.payloadFluxToReturn;
		}
	}

}