/* * Copyright © 2019 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.servicetalk.concurrent.api.internal; import io.servicetalk.concurrent.PublisherSource.Subscriber; import io.servicetalk.concurrent.PublisherSource.Subscription; import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.TestPublisherSubscriber; import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.rules.Timeout; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.Random; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; import static io.servicetalk.concurrent.internal.TerminalNotification.complete; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Runtime.getRuntime; import static java.lang.System.arraycopy; import static java.nio.charset.StandardCharsets.ISO_8859_1; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.fail; public class ConnectablePayloadWriterTest { private static final Logger LOGGER = LoggerFactory.getLogger(ConnectablePayloadWriterTest.class); @Rule public final Timeout timeout = new ServiceTalkTestTimeout(); @Rule public final ExpectedException expectedException = ExpectedException.none(); private final TestPublisherSubscriber<String> subscriber = new TestPublisherSubscriber<>(); private ConnectablePayloadWriter<String> cpw; private ExecutorService executorService; @Before public void setUp() { cpw = new ConnectablePayloadWriter<>(); executorService = Executors.newCachedThreadPool(); } @After public void teardown() { executorService.shutdown(); } @Test public void subscribeDeliverDataSynchronously() throws Exception { AtomicReference<Future<?>> futureRef = new AtomicReference<>(); toSource(cpw.connect().afterOnSubscribe(subscription -> { subscriber.request(1); // request from the TestPublisherSubscriber! // We want to increase the chance that the writer thread has to wait for the Subscriber to become // available, instead of waiting for the requestN demand. CyclicBarrier barrier = new CyclicBarrier(2); futureRef.compareAndSet(null, executorService.submit(toRunnable(() -> { barrier.await(); cpw.write("foo"); cpw.flush(); cpw.close(); }))); try { barrier.await(); } catch (InterruptedException | BrokenBarrierException e) { throw new RuntimeException(e); } })).subscribe(subscriber); Future<?> f = futureRef.get(); assertNotNull(f); f.get(); assertThat(subscriber.takeItems(), contains("foo")); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void subscribeCloseSynchronously() throws Exception { AtomicReference<Future<?>> futureRef = new AtomicReference<>(); toSource(cpw.connect().afterOnSubscribe(subscription -> { // We want to increase the chance that the writer thread has to wait for the Subscriber to become // available, instead of waiting for the requestN demand. CyclicBarrier barrier = new CyclicBarrier(2); futureRef.compareAndSet(null, executorService.submit(toRunnable(() -> { barrier.await(); cpw.close(); }))); try { barrier.await(); } catch (InterruptedException | BrokenBarrierException e) { throw new RuntimeException(e); } })).subscribe(subscriber); Future<?> f = futureRef.get(); assertNotNull(f); f.get(); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void writeAfterCloseShouldThrow() throws IOException { cpw.close(); expectedException.expect(IOException.class); cpw.write("foo"); // Make sure the Subscription thread isn't blocked. subscriber.request(1); subscriber.cancel(); } @Test public void multipleWriteAfterCloseShouldThrow() throws Exception { Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); cpw.close(); cpw.write("bar"); cpw.flush(); })); toSource(cpw.connect()).subscribe(subscriber); subscriber.request(2); try { f.get(); fail(); } catch (ExecutionException e) { verifyCheckedRunnableException(e, IOException.class); } assertThat(subscriber.takeItems(), contains("foo")); assertThat(subscriber.takeTerminal(), is(complete())); // Make sure the Subscription thread isn't blocked. subscriber.request(1); subscriber.cancel(); } @Test public void connectMultipleWriteAfterCloseShouldThrow() throws Exception { toSource(cpw.connect()).subscribe(subscriber); subscriber.request(2); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); cpw.close(); cpw.write("bar"); cpw.flush(); })); try { f.get(); fail(); } catch (ExecutionException e) { verifyCheckedRunnableException(e, IOException.class); } assertThat(subscriber.takeItems(), contains("foo")); assertThat(subscriber.takeTerminal(), is(complete())); // Make sure the Subscription thread isn't blocked. subscriber.request(1); subscriber.cancel(); } @Test public void cancelUnblocksWrite() throws Exception { CyclicBarrier afterFlushBarrier = new CyclicBarrier(2); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); afterFlushBarrier.await(); cpw.write("bar"); cpw.flush(); })); toSource(cpw.connect()).subscribe(subscriber); subscriber.request(1); afterFlushBarrier.await(); subscriber.cancel(); try { f.get(); fail(); } catch (ExecutionException e) { verifyCheckedRunnableException(e, IOException.class); } assertThat(subscriber.takeItems(), contains("foo")); assertThat(subscriber.takeTerminal(), is(complete())); cpw.close(); // should be idempotent // Make sure the Subscription thread isn't blocked. subscriber.request(1); subscriber.cancel(); } @Test public void connectCancelUnblocksWrite() throws Exception { toSource(cpw.connect()).subscribe(subscriber); subscriber.cancel(); Future<?> f = executorService.submit(toRunnable(() -> cpw.write("foo"))); try { f.get(); fail(); } catch (ExecutionException e) { verifyCheckedRunnableException(e, IOException.class); } assertThat(subscriber.takeItems(), is(empty())); assertThat(subscriber.takeTerminal(), is(complete())); cpw.close(); // should be idempotent // Make sure the Subscription thread isn't blocked. subscriber.request(1); subscriber.cancel(); } @Test public void closeShouldBeIdempotent() throws Exception { Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); cpw.close(); })); toSource(cpw.connect()).subscribe(subscriber); subscriber.request(1); f.get(); assertThat(subscriber.takeItems(), contains("foo")); assertThat(subscriber.takeTerminal(), is(complete())); cpw.close(); // should be idempotent } @Test public void closeShouldBeIdempotentWhenNotSubscribed() throws IOException { cpw.connect(); cpw.close(); cpw.close(); // should be idempotent } @Test public void multipleConnectWithInvalidRequestnShouldFailConnect() throws Exception { CountDownLatch onSubscribe = new CountDownLatch(1); CountDownLatch onComplete = new CountDownLatch(1); AtomicReference<Throwable> errorRef = new AtomicReference<>(); toSource(cpw.connect()).subscribe(new Subscriber<String>() { @Override public void onSubscribe(final Subscription s) { s.request(-1); onSubscribe.countDown(); } @Override public void onNext(final String str) { } @Override public void onError(final Throwable t) { errorRef.set(t); } @Override public void onComplete() { onComplete.countDown(); } }); cpw.close(); onSubscribe.await(); assertThat(errorRef.get(), instanceOf(IllegalArgumentException.class)); toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeError(), is(instanceOf(IllegalStateException.class))); assertThat(onComplete.getCount(), equalTo(1L)); } @Test public void multipleConnectWhileEmittingShouldFailConnect() throws Exception { CountDownLatch onNext = new CountDownLatch(1); CountDownLatch onComplete = new CountDownLatch(1); toSource(cpw.connect()).subscribe(new Subscriber<String>() { @Override public void onSubscribe(final Subscription s) { s.request(1); } @Override public void onNext(final String str) { onNext.countDown(); } @Override public void onError(final Throwable t) { } @Override public void onComplete() { onComplete.countDown(); } }); cpw.write("foo"); cpw.flush(); cpw.close(); onNext.await(); toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeError(), is(instanceOf(IllegalStateException.class))); onComplete.await(); } @Test public void multipleConnectWhileSubscribedShouldFailConnect() throws Exception { CountDownLatch onSubscribe = new CountDownLatch(1); CountDownLatch onComplete = new CountDownLatch(1); toSource(cpw.connect()).subscribe(new Subscriber<String>() { @Override public void onSubscribe(final Subscription s) { onSubscribe.countDown(); } @Override public void onNext(final String str) { } @Override public void onError(final Throwable t) { } @Override public void onComplete() { onComplete.countDown(); } }); cpw.close(); onSubscribe.await(); toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeError(), is(instanceOf(IllegalStateException.class))); onComplete.await(); } @Test public void multipleConnectWhileSubscriberFailedShouldFailConnect() throws Exception { CountDownLatch onError = new CountDownLatch(1); CountDownLatch onComplete = new CountDownLatch(1); toSource(cpw.connect()).subscribe(new Subscriber<String>() { @Override public void onSubscribe(final Subscription s) { s.request(1); } @Override public void onNext(final String str) { throw DELIBERATE_EXCEPTION; } @Override public void onError(final Throwable t) { onError.countDown(); } @Override public void onComplete() { onComplete.countDown(); } }); try { cpw.write("foo"); fail(); } catch (RuntimeException cause) { assertSame(DELIBERATE_EXCEPTION, cause); } try { cpw.flush(); fail(); } catch (IOException ignored) { // expected } cpw.close(); onError.await(); toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeError(), is(instanceOf(IllegalStateException.class))); assertThat(onComplete.getCount(), equalTo(1L)); } @Test public void writeFlushCloseConnectSubscribeRequest() throws Exception { Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); cpw.close(); })); toSource(cpw.connect()).subscribe(subscriber); subscriber.request(1); f.get(); assertThat(subscriber.takeItems(), contains("foo")); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void connectSubscribeRequestWriteFlushClose() throws Exception { toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeItems(), is(empty())); subscriber.request(1); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); cpw.close(); })); f.get(); assertThat(subscriber.takeItems(), contains("foo")); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void connectSubscribeWriteFlushCloseRequest() throws Exception { toSource(cpw.connect()).subscribe(subscriber); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); cpw.close(); })); assertThat(subscriber.takeItems(), is(empty())); subscriber.request(1); f.get(); assertThat(subscriber.takeItems(), contains("foo")); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void requestWriteSingleWriteSingleFlushClose() throws Exception { toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeItems(), is(empty())); subscriber.request(2); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.write("bar"); cpw.flush(); cpw.close(); })); f.get(); assertThat(subscriber.takeItems(), contains("foo", "bar")); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void requestWriteSingleFlushWriteSingleFlushClose() throws Exception { toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeItems(), is(empty())); subscriber.request(2); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); cpw.write("bar"); cpw.flush(); cpw.close(); })); f.get(); assertThat(subscriber.takeItems(), contains("foo", "bar")); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void writeSingleFlushWriteSingleFlushRequestClose() throws Exception { toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeItems(), is(empty())); subscriber.request(1); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); cpw.write("bar"); cpw.flush(); cpw.close(); })); subscriber.request(1); f.get(); assertThat(subscriber.takeItems(), contains("foo", "bar")); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void invalidRequestN() throws IOException { AtomicReference<Throwable> failure = new AtomicReference<>(); toSource(cpw.connect()).subscribe(new Subscriber<String>() { @Override public void onSubscribe(final Subscription s) { s.request(-1); } @Override public void onNext(final String str) { failure.set(new AssertionError("onNext received for illegal request-n")); } @Override public void onError(final Throwable t) { failure.set(t); } @Override public void onComplete() { failure.set(new AssertionError("onComplete received for illegal request-n")); } }); cpw.close(); assertThat("Unexpected failure", failure.get(), is(instanceOf(IllegalArgumentException.class))); } @Test public void onNextThrows() throws IOException { AtomicReference<Throwable> failure = new AtomicReference<>(); toSource(cpw.connect()).subscribe(new Subscriber<String>() { @Override public void onSubscribe(final Subscription s) { s.request(1); } @Override public void onNext(final String str) { throw DELIBERATE_EXCEPTION; } @Override public void onError(final Throwable t) { failure.set(t); } @Override public void onComplete() { failure.set(new AssertionError("onComplete received when onNext threw.")); } }); try { cpw.write("foo"); fail(); } catch (RuntimeException cause) { assertSame(DELIBERATE_EXCEPTION, cause); } cpw.close(); assertThat("Unexpected failure", failure.get(), is(DELIBERATE_EXCEPTION)); } @Test public void cancelCloses() throws Exception { toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeItems(), is(empty())); subscriber.cancel(); Future<?> f = executorService.submit(toRunnable(() -> cpw.write("foo"))); expectedException.expect(ExecutionException.class); expectedException.expectCause(is(instanceOf(RuntimeException.class))); f.get(); } @Test public void cancelCloseAfterWrite() throws Exception { toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeItems(), is(empty())); subscriber.request(1); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); })); f.get(); assertThat(subscriber.takeItems(), contains("foo")); subscriber.cancel(); expectedException.expect(is(instanceOf(IOException.class))); cpw.write("foo"); } @Test public void requestNegativeWrite() throws Exception { toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeItems(), is(empty())); subscriber.request(-1); Future<?> f = executorService.submit(toRunnable(() -> { cpw.write("foo"); cpw.flush(); })); try { f.get(); fail(); } catch (ExecutionException e) { verifyCheckedRunnableException(e, IOException.class); } assertThat(subscriber.takeError(), is(instanceOf(IllegalArgumentException.class))); } @Test public void writeRequestNegative() throws Exception { toSource(cpw.connect()).subscribe(subscriber); assertThat(subscriber.takeItems(), is(empty())); CyclicBarrier cb = new CyclicBarrier(2); Future<?> f = executorService.submit(toRunnable(() -> { cb.await(); cpw.write("foo"); cpw.flush(); })); cb.await(); subscriber.request(-1); try { f.get(); fail(); } catch (ExecutionException e) { verifyCheckedRunnableException(e, IOException.class); } assertThat(subscriber.takeError(), is(instanceOf(IllegalArgumentException.class))); } @Test public void closeNoWrite() throws Exception { CyclicBarrier cb = new CyclicBarrier(2); Future<?> f = executorService.submit(toRunnable(() -> { cb.await(); cpw.close(); })); final Publisher<String> connect = cpw.connect(); cb.await(); toSource(connect).subscribe(subscriber); subscriber.request(1); f.get(); assertThat(subscriber.takeItems(), is(empty())); assertThat(subscriber.takeTerminal(), is(complete())); } @Test public void multiThreadedProducerConsumer() throws Exception { final Random r = new Random(); final long seed = r.nextLong(); // capture seed to have repeatable tests r.setSeed(seed); // 3% of heap or max of 100 MiB final int dataSize = (int) min(getRuntime().maxMemory() * 0.03, 100 * 1024 * 1024); LOGGER.info("Test seed = {} – data size = {}", seed, dataSize); final AtomicReference<Throwable> error = new AtomicReference<>(); final byte[] data = new byte[dataSize]; final byte[] received = new byte[dataSize]; for (int i = 0; i < dataSize; ++i) { // Single character encoding data[i] = (byte) (r.nextInt((Byte.MAX_VALUE - 1) + 1) + 1); } final Publisher<String> pub = cpw.connect(); final Thread producerThread = new Thread(() -> { int writeIndex = 0; try { while (writeIndex < dataSize) { // write at most 25% of remaining bytes final int length = (int) max(1, r.nextInt(dataSize - (writeIndex - 1)) * 0.25); LOGGER.debug("Writing {} bytes - writeIndex = {}", length, writeIndex); cpw.write(new String(data, writeIndex, length, ISO_8859_1)); writeIndex += length; if (r.nextDouble() < 0.4) { LOGGER.debug("Flushing - writeIndex = {}", writeIndex); cpw.flush(); } } LOGGER.debug("Closing - writeIndex = {}", writeIndex); cpw.close(); } catch (Throwable t) { error.compareAndSet(null, t); } }); final Thread consumerThread = new Thread(() -> { try { final CountDownLatch consumerDone = new CountDownLatch(1); toSource(pub).subscribe(new Subscriber<String>() { @Nullable private Subscription sub; private int writeIndex; @Override public void onSubscribe(final Subscription s) { sub = s; sub.request(1); } @Override public void onNext(final String str) { LOGGER.debug("Reading {} bytes - writeIndex = {}", str.length(), writeIndex); byte[] bytes = str.getBytes(ISO_8859_1); arraycopy(bytes, 0, received, writeIndex, bytes.length); writeIndex += bytes.length; assert sub != null : "Subscription can not be null in onNext."; sub.request(1); } @Override public void onError(final Throwable t) { error.compareAndSet(null, t); consumerDone.countDown(); } @Override public void onComplete() { consumerDone.countDown(); } }); consumerDone.await(); } catch (Throwable t) { error.compareAndSet(null, t); } }); producerThread.start(); consumerThread.start(); // make sure both threads exit producerThread.join(); consumerThread.join(); // provides visibility for received from consumerThread assertNull(error.get()); assertArrayEquals(data, received); // assertThat() times out } static Runnable toRunnable(CheckedRunnable runnable) { return () -> { try { runnable.doWork(); } catch (Exception e) { throw new RuntimeException(e); } }; } @FunctionalInterface interface CheckedRunnable { void doWork() throws Exception; } static void verifyCheckedRunnableException(ExecutionException e, Class<? extends Throwable> clazz) { assertThat(e.getCause(), is(instanceOf(RuntimeException.class))); // this is from toRunnable assertThat(e.getCause().getCause(), is(instanceOf(clazz))); } }