/* * Copyright 2019 Red Hat, Inc. * * Red Hat licenses this file to you under the Apache License, version 2.0 * (the "License"); you may not use this file except in compliance with the * License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package io.vertx.ext.web.handler.graphql; import graphql.GraphQL; import graphql.schema.DataFetchingEnvironment; import graphql.schema.GraphQLSchema; import graphql.schema.idl.RuntimeWiring; import graphql.schema.idl.SchemaGenerator; import graphql.schema.idl.SchemaParser; import graphql.schema.idl.TypeDefinitionRegistry; import io.vertx.core.AsyncResult; import io.vertx.core.Handler; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpClientOptions; import io.vertx.core.http.WebSocket; import io.vertx.core.http.WebSocketFrame; import io.vertx.core.json.JsonObject; import io.vertx.core.net.NetClientOptions; import io.vertx.core.net.NetServer; import io.vertx.core.net.NetSocket; import io.vertx.ext.web.WebTestBase; import org.junit.Test; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; import java.util.HashMap; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.stream.IntStream; import static graphql.schema.idl.RuntimeWiring.newRuntimeWiring; import static io.vertx.core.http.HttpMethod.GET; import static io.vertx.ext.web.handler.graphql.ApolloWSMessageType.COMPLETE; import static io.vertx.ext.web.handler.graphql.ApolloWSMessageType.DATA; /** * @author Rogelio Orts */ public class ApolloWSHandlerTest extends WebTestBase { private static final int MAX_COUNT = 4; private static final int STATIC_COUNT = 5; private ApolloWSOptions apolloWSOptions = new ApolloWSOptions(); private AtomicReference<Subscription> subscriptionRef = new AtomicReference<>(); @Override public void setUp() throws Exception { super.setUp(); GraphQL graphQL = graphQL(); router.route("/graphql").handler(ApolloWSHandler.create(graphQL, apolloWSOptions)); router.route("/graphql").handler(GraphQLHandler.create(graphQL)); } protected GraphQL graphQL() { String schema = vertx.fileSystem().readFileBlocking("counter.graphqls").toString(); SchemaParser schemaParser = new SchemaParser(); TypeDefinitionRegistry typeDefinitionRegistry = schemaParser.parse(schema); RuntimeWiring runtimeWiring = newRuntimeWiring() .type("Query", builder -> builder.dataFetcher("staticCounter", this::getStaticCounter)) .type("Subscription", builder -> builder.dataFetcher("counter", this::getCounter)) .build(); SchemaGenerator schemaGenerator = new SchemaGenerator(); GraphQLSchema graphQLSchema = schemaGenerator.makeExecutableSchema(typeDefinitionRegistry, runtimeWiring); return GraphQL.newGraphQL(graphQLSchema) .build(); } private Map<String, Object> getStaticCounter(DataFetchingEnvironment env) { int count = env.getArgument("num"); Map<String, Object> counter = new HashMap<>(); counter.put("count", count); return counter; } private Publisher<Map<String, Object>> getCounter(DataFetchingEnvironment env) { boolean finite = env.getArgument("finite"); return subscriber -> { Subscription subscription = new Subscription() { @Override public void request(long n) { } @Override public void cancel() { if (!subscriptionRef.compareAndSet(this, null)) { fail(); } } }; if (!subscriptionRef.compareAndSet(null, subscription)) { fail(); } subscriber.onSubscribe(subscription); IntStream.range(0, 5).forEach(num -> { Map<String, Object> counter = new HashMap<>(); counter.put("count", num); subscriber.onNext(counter); }); if (finite) { subscriber.onComplete(); if (!subscriptionRef.compareAndSet(subscription, null)) { fail(); } } }; } @Test public void testSubscriptionWsCall() { waitFor(MAX_COUNT + 2); client.webSocket("/graphql", onSuccess(websocket -> { websocket.exceptionHandler(this::fail); AtomicReference<String> id = new AtomicReference<>(); AtomicInteger counter = new AtomicInteger(); websocket.textMessageHandler(text -> { JsonObject obj = new JsonObject(text); int current = counter.getAndIncrement(); if (current >= 0 && current <= MAX_COUNT) { if (current == 0) { assertTrue(id.compareAndSet(null, obj.getString("id"))); } else { assertEquals(id.get(), obj.getString("id")); } assertEquals(DATA, ApolloWSMessageType.from(obj.getString("type"))); int val = obj.getJsonObject("payload").getJsonObject("data").getJsonObject("counter").getInteger("count"); assertEquals(current, val); complete(); } else if (current == MAX_COUNT + 1) { assertEquals(id.get(), obj.getString("id")); assertEquals(COMPLETE, ApolloWSMessageType.from(obj.getString("type"))); complete(); } else { fail(); } }); JsonObject message = new JsonObject() .put("payload", new JsonObject() .put("query", "subscription Subscription { counter { count } }")) .put("type", "start") .put("id", "1"); websocket.write(message.toBuffer()); })); await(); } @Test public void testQueryWsCall() { testQueryWsCall((webSocket, message) -> webSocket.write(message.toBuffer())); } @Test public void testQueryWsCallMultipleFrames() { testQueryWsCall((webSocket, message) -> { Buffer buffer = message.toBuffer(); int part = buffer.length() / 3; if (part == 0) fail("Cannot perform test"); webSocket.writeFrame(WebSocketFrame.binaryFrame(buffer.getBuffer(0, part), false)); webSocket.writeFrame(WebSocketFrame.continuationFrame(buffer.getBuffer(part, 2 * part), false)); webSocket.writeFrame(WebSocketFrame.continuationFrame(buffer.getBuffer(2 * part, buffer.length()), true)); }); } private void testQueryWsCall(BiConsumer<WebSocket, JsonObject> sender) { waitFor(2); client.webSocket("/graphql", onSuccess(websocket -> { websocket.exceptionHandler(this::fail); AtomicReference<String> id = new AtomicReference<>(); AtomicInteger counter = new AtomicInteger(); websocket.textMessageHandler(text -> { JsonObject obj = new JsonObject(text); int current = counter.getAndIncrement(); if (current == 0) { assertTrue(id.compareAndSet(null, obj.getString("id"))); assertEquals(DATA, ApolloWSMessageType.from(obj.getString("type"))); int val = obj.getJsonObject("payload").getJsonObject("data").getJsonObject("staticCounter").getInteger("count"); assertEquals(STATIC_COUNT, val); complete(); } else if (current == 1) { assertEquals(id.get(), obj.getString("id")); assertEquals(COMPLETE, ApolloWSMessageType.from(obj.getString("type"))); complete(); } else { fail(); } }); JsonObject message = new JsonObject() .put("payload", new JsonObject() .put("query", "query Query { staticCounter { count } }")) .put("type", "start") .put("id", "1"); sender.accept(websocket, message); })); await(); } @Test public void testQueryHttpCall() throws Exception { String query = "query Query { staticCounter { count } }"; GraphQLRequest request = new GraphQLRequest() .setMethod(GET) .setGraphQLQuery(query); request.send(client, onSuccess(body -> { int count = body.getJsonObject("data") .getJsonObject("staticCounter") .getInteger("count"); assertEquals(STATIC_COUNT, count); complete(); })); await(); } @Test public void testWsKeepAlive() { apolloWSOptions.setKeepAlive(100L); client.webSocket("/graphql", onSuccess(websocket -> { websocket.exceptionHandler(this::fail); AtomicInteger counter = new AtomicInteger(0); websocket.textMessageHandler(text -> { try { JsonObject obj = new JsonObject(text); if (counter.getAndIncrement() == 0) { assertEquals(ApolloWSMessageType.CONNECTION_ACK.getText(), obj.getString("type")); } else { assertEquals(ApolloWSMessageType.CONNECTION_KEEP_ALIVE.getText(), obj.getString("type")); complete(); } } catch (Exception e) { fail(e); } }); JsonObject message = new JsonObject() .put("type", "connection_init"); websocket.write(message.toBuffer()); })); await(); } @Test public void testSubscriptionCanceledOnAbruptClose() throws Exception { HttpClientOptions clientOptions = getHttpClientOptions(); int backendPort = clientOptions.getDefaultPort(); int proxyPort = backendPort + 101; Proxy proxy = new Proxy(clientOptions.getDefaultHost(), proxyPort, backendPort); proxy.start(); client.close(); client = vertx.createHttpClient(clientOptions.setDefaultPort(proxyPort)); client.webSocket("/graphql", onSuccess(websocket -> { websocket.exceptionHandler(this::fail); AtomicInteger counter = new AtomicInteger(); websocket.textMessageHandler(text -> { if (counter.getAndIncrement() == MAX_COUNT) { if (subscriptionRef.get() == null) { fail("Expected a live subscription"); } else { proxy.closeAbruptly(onSuccess(v -> { testComplete(); })); } } }); JsonObject message = new JsonObject() .put("payload", new JsonObject() .put("query", "subscription Subscription { counter(finite: false) { count } }")) .put("type", "start") .put("id", "1"); websocket.write(message.toBuffer()); })); await(); assertWaitUntil(() -> subscriptionRef.get() == null); } // We need this proxy to make sure the connection to the backend is reset abruptly // Otherwise the Vert.x HttpClient closes the websocket properly before closing the TCP connection private class Proxy { final String host; final int serverPort, clientPort; volatile NetServer server; volatile NetSocket client; Proxy(String host, int serverPort, int clientPort) { this.host = host; this.serverPort = serverPort; this.clientPort = clientPort; } void start() throws Exception { CountDownLatch latch = new CountDownLatch(1); vertx.createNetServer() .exceptionHandler(Throwable::printStackTrace) .connectHandler(socket -> { socket.pause(); vertx.createNetClient(new NetClientOptions().setSoLinger(0)) .connect(clientPort, host) .onSuccess(client -> { this.client = client; socket.pipeTo(client, v -> socket.close()); client.pipeTo(socket, v -> socket.close()); socket.resume(); }); }) .listen(serverPort, host) .onFailure(cause -> fail(cause)) .onSuccess(server -> { this.server = server; latch.countDown(); }); awaitLatch(latch); } void closeAbruptly(Handler<AsyncResult<Void>> handler) { client.close().onComplete(ar -> { server.close(); handler.handle(ar); }); } } }