/* * Copyright 2015 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.grpc.testing.integration; import static com.google.common.truth.Truth.assertAbout; import static io.grpc.testing.DeadlineSubject.deadline; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import com.google.common.util.concurrent.SettableFuture; import io.grpc.Context; import io.grpc.Context.CancellableContext; import io.grpc.Deadline; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Server; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptors; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.grpc.testing.integration.Messages.SimpleRequest; import io.grpc.testing.integration.Messages.SimpleResponse; import java.io.IOException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; import org.mockito.MockitoAnnotations; /** * Integration test for various forms of cancellation and deadline propagation. */ @RunWith(JUnit4.class) public class CascadingTest { @Mock TestServiceGrpc.TestServiceImplBase service; private ManagedChannel channel; private Server server; private CountDownLatch observedCancellations; private CountDownLatch receivedCancellations; private TestServiceGrpc.TestServiceBlockingStub blockingStub; private TestServiceGrpc.TestServiceStub asyncStub; private TestServiceGrpc.TestServiceFutureStub futureStub; private ExecutorService otherWork; @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); // Use a cached thread pool as we need a thread for each blocked call otherWork = Executors.newCachedThreadPool(); channel = InProcessChannelBuilder.forName("channel").executor(otherWork).build(); blockingStub = TestServiceGrpc.newBlockingStub(channel); asyncStub = TestServiceGrpc.newStub(channel); futureStub = TestServiceGrpc.newFutureStub(channel); } @After public void tearDown() { channel.shutdownNow(); server.shutdownNow(); otherWork.shutdownNow(); } /** * Test {@link Context} cancellation propagates from the first node in the call chain all the way * to the last. */ @Test public void testCascadingCancellationViaOuterContextCancellation() throws Exception { observedCancellations = new CountDownLatch(2); receivedCancellations = new CountDownLatch(3); Future<?> chainReady = startChainingServer(3); CancellableContext context = Context.current().withCancellation(); Future<SimpleResponse> future; Context prevContext = context.attach(); try { future = futureStub.unaryCall(SimpleRequest.getDefaultInstance()); } finally { context.detach(prevContext); } chainReady.get(5, TimeUnit.SECONDS); context.cancel(null); try { future.get(5, TimeUnit.SECONDS); fail("Expected cancellation"); } catch (ExecutionException ex) { Status status = Status.fromThrowable(ex); assertEquals(Status.Code.CANCELLED, status.getCode()); // Should have observed 2 cancellations responses from downstream servers if (!observedCancellations.await(5, TimeUnit.SECONDS)) { fail("Expected number of cancellations not observed by clients"); } if (!receivedCancellations.await(5, TimeUnit.SECONDS)) { fail("Expected number of cancellations to be received by servers not observed"); } } } /** * Test that cancellation via call cancellation propagates down the call. */ @Test public void testCascadingCancellationViaRpcCancel() throws Exception { observedCancellations = new CountDownLatch(2); receivedCancellations = new CountDownLatch(3); Future<?> chainReady = startChainingServer(3); Future<SimpleResponse> future = futureStub.unaryCall(SimpleRequest.getDefaultInstance()); chainReady.get(5, TimeUnit.SECONDS); future.cancel(true); assertTrue(future.isCancelled()); if (!observedCancellations.await(5, TimeUnit.SECONDS)) { fail("Expected number of cancellations not observed by clients"); } if (!receivedCancellations.await(5, TimeUnit.SECONDS)) { fail("Expected number of cancellations to be received by servers not observed"); } } /** * Test that when RPC cancellation propagates up a call chain, the cancellation of the parent * RPC triggers cancellation of all of its children. */ @Test public void testCascadingCancellationViaLeafFailure() throws Exception { // All nodes (15) except one edge of the tree (4) will be cancelled. observedCancellations = new CountDownLatch(11); receivedCancellations = new CountDownLatch(11); startCallTreeServer(3); try { // Use response size limit to control tree nodeCount. blockingStub.unaryCall(Messages.SimpleRequest.newBuilder().setResponseSize(3).build()); fail("Expected abort"); } catch (StatusRuntimeException sre) { // Wait for the workers to finish Status status = sre.getStatus(); // Outermost caller observes ABORTED propagating up from the failing leaf, // The descendant RPCs are cancelled so they receive CANCELLED. assertEquals(Status.Code.ABORTED, status.getCode()); if (!observedCancellations.await(5, TimeUnit.SECONDS)) { fail("Expected number of cancellations not observed by clients"); } if (!receivedCancellations.await(5, TimeUnit.SECONDS)) { fail("Expected number of cancellations to be received by servers not observed"); } } } @Test public void testDeadlinePropagation() throws Exception { final AtomicInteger recursionDepthRemaining = new AtomicInteger(3); final SettableFuture<Deadline> finalDeadline = SettableFuture.create(); class DeadlineSaver extends TestServiceGrpc.TestServiceImplBase { @Override public void unaryCall(final SimpleRequest request, final StreamObserver<SimpleResponse> responseObserver) { Context.currentContextExecutor(otherWork).execute(new Runnable() { @Override public void run() { try { if (recursionDepthRemaining.decrementAndGet() == 0) { finalDeadline.set(Context.current().getDeadline()); responseObserver.onNext(SimpleResponse.getDefaultInstance()); } else { responseObserver.onNext(blockingStub.unaryCall(request)); } responseObserver.onCompleted(); } catch (Exception ex) { responseObserver.onError(ex); } } }); } } server = InProcessServerBuilder.forName("channel").executor(otherWork) .addService(new DeadlineSaver()) .build().start(); Deadline initialDeadline = Deadline.after(1, TimeUnit.MINUTES); blockingStub.withDeadline(initialDeadline).unaryCall(SimpleRequest.getDefaultInstance()); assertNotSame(initialDeadline, finalDeadline); // Since deadline is re-calculated at each hop, some variance is acceptable and expected. assertAbout(deadline()) .that(finalDeadline.get()).isWithin(1, TimeUnit.SECONDS).of(initialDeadline); } /** * Create a chain of client to server calls which can be cancelled top down. * * @return a Future that completes when call chain is created */ private Future<?> startChainingServer(final int depthThreshold) throws IOException { final AtomicInteger serversReady = new AtomicInteger(); final SettableFuture<Void> chainReady = SettableFuture.create(); class ChainingService extends TestServiceGrpc.TestServiceImplBase { @Override public void unaryCall(final SimpleRequest request, final StreamObserver<SimpleResponse> responseObserver) { ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(new Runnable() { @Override public void run() { receivedCancellations.countDown(); } }); if (serversReady.incrementAndGet() == depthThreshold) { // Stop recursion chainReady.set(null); return; } Context.currentContextExecutor(otherWork).execute(new Runnable() { @Override public void run() { try { blockingStub.unaryCall(request); } catch (StatusRuntimeException e) { Status status = e.getStatus(); if (status.getCode() == Status.Code.CANCELLED) { observedCancellations.countDown(); } else { responseObserver.onError(e); } } } }); } } server = InProcessServerBuilder.forName("channel").executor(otherWork) .addService(new ChainingService()) .build().start(); return chainReady; } /** * Create a tree of client to server calls where each received call on the server * fans out to two downstream calls. Uses SimpleRequest.response_size to limit the nodeCount * of the tree. One of the leaves will ABORT to trigger cancellation back up to tree. */ private void startCallTreeServer(int depthThreshold) throws IOException { final AtomicInteger nodeCount = new AtomicInteger((2 << depthThreshold) - 1); server = InProcessServerBuilder.forName("channel").executor(otherWork).addService( ServerInterceptors.intercept(service, new ServerInterceptor() { @Override public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall( final ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) { // Respond with the headers but nothing else. call.sendHeaders(new Metadata()); call.request(1); return new ServerCall.Listener<ReqT>() { @Override public void onMessage(final ReqT message) { Messages.SimpleRequest req = (Messages.SimpleRequest) message; if (nodeCount.decrementAndGet() == 0) { // we are in the final leaf node so trigger an ABORT upwards Context.currentContextExecutor(otherWork).execute(new Runnable() { @Override public void run() { call.close(Status.ABORTED, new Metadata()); } }); } else if (req.getResponseSize() != 0) { // We are in a non leaf node so fire off two requests req = req.toBuilder().setResponseSize(req.getResponseSize() - 1).build(); for (int i = 0; i < 2; i++) { asyncStub.unaryCall(req, new StreamObserver<Messages.SimpleResponse>() { @Override public void onNext(Messages.SimpleResponse value) { } @Override public void onError(Throwable t) { Status status = Status.fromThrowable(t); if (status.getCode() == Status.Code.CANCELLED) { observedCancellations.countDown(); } // Propagate closure upwards. try { call.close(status, new Metadata()); } catch (IllegalStateException t2) { // Ignore error if already closed. } } @Override public void onCompleted() { } }); } } } @Override public void onCancel() { receivedCancellations.countDown(); } }; } }) ).build(); server.start(); } }