/*
 * Copyright 2019 LINE Corporation
 *
 * LINE Corporation licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.linecorp.armeria.client;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.linecorp.armeria.client.endpoint.dns.TestDnsServer.newAddressRecord;
import static io.netty.handler.codec.dns.DnsRecordType.A;
import static io.netty.handler.codec.dns.DnsRecordType.AAAA;
import static io.netty.handler.codec.dns.DnsSection.ANSWER;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.awaitility.Awaitility.await;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.google.common.collect.ImmutableMap;

import com.linecorp.armeria.client.RefreshingAddressResolver.CacheEntry;
import com.linecorp.armeria.client.endpoint.dns.TestDnsServer;
import com.linecorp.armeria.client.retry.Backoff;
import com.linecorp.armeria.testing.junit5.common.EventLoopExtension;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DatagramDnsQuery;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsResponse;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.resolver.AddressResolver;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.resolver.dns.DnsNameResolverTimeoutException;
import io.netty.resolver.dns.DnsServerAddressStreamProvider;
import io.netty.resolver.dns.DnsServerAddresses;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;

/**
 * Tests for the resolvers created by {@link RefreshingAddressResolverGroup}.
 *
 * <p>Every test starts one or more {@link TestDnsServer}s with canned responses, builds a resolver group
 * that queries only those servers (see {@link #builder(TestDnsServer...)}) and then inspects the resolved
 * addresses and the contents of {@link RefreshingAddressResolverGroup#cache()}.
 */
class RefreshingAddressResolverTest {

    @RegisterExtension
    static final EventLoopExtension eventLoopExtension = new EventLoopExtension();

    /**
     * Resolves two different host names and verifies that each one adds exactly one cache entry, and that
     * resolving an already-resolved host name with a different port yields the same address with the new
     * port without growing the cache.
     */
    @Test
    void resolve() throws Exception {
        try (TestDnsServer server = new TestDnsServer(ImmutableMap.of(
                new DefaultDnsQuestion("foo.com.", A),
                new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "1.1.1.1")),
                new DefaultDnsQuestion("bar.com.", A),
                new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("bar.com.", "1.2.3.4"))))
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            try (RefreshingAddressResolverGroup group = builder(server).build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);
                final Future<InetSocketAddress> foo = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().untilAsserted(() -> assertThat(foo.isSuccess()).isTrue());
                InetSocketAddress addr = foo.getNow();
                assertThat(addr.getAddress().getHostAddress()).isEqualTo("1.1.1.1");
                assertThat(addr.getPort()).isEqualTo(36462);

                final ConcurrentMap<String, CompletableFuture<CacheEntry>> cache = group.cache();
                assertThat(cache.size()).isOne();

                final Future<InetSocketAddress> bar = resolver.resolve(
                        InetSocketAddress.createUnresolved("bar.com", 36462));
                await().untilAsserted(() -> assertThat(bar.isSuccess()).isTrue());
                addr = bar.getNow();
                assertThat(addr.getAddress().getHostAddress()).isEqualTo("1.2.3.4");
                assertThat(addr.getPort()).isEqualTo(36462);
                assertThat(cache.size()).isEqualTo(2);

                // The result is read without awaiting; presumably the already-cached entry completes
                // the future before resolve() returns.
                final Future<InetSocketAddress> foo1 = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 80));
                addr = foo1.getNow();
                assertThat(addr.getAddress().getHostAddress()).isEqualTo("1.1.1.1");
                assertThat(addr.getPort()).isEqualTo(80);
                assertThat(cache.size()).isEqualTo(2);

                final List<InetAddress> addresses =
                        cache.values()
                             .stream()
                             .map(future -> future.join().address())
                             .collect(toImmutableList());
                assertThat(addresses).containsExactlyInAnyOrder(
                        InetAddress.getByAddress("foo.com", new byte[] { 1, 1, 1, 1 }),
                        InetAddress.getByAddress("bar.com", new byte[] { 1, 2, 3, 4 }));
            }
        }
    }

    /**
     * Resolves a host name once and verifies that its cache entry eventually disappears, but not before
     * 90% of one second has elapsed since the resolution started.
     */
    @Test
    void removedWhenNoCacheHit() throws Exception {
        try (TestDnsServer server = new TestDnsServer(ImmutableMap.of(
                new DefaultDnsQuestion("foo.com.", A),
                new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "1.1.1.1", 1))))
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            final DnsResolverGroupBuilder builder = builder(server);
            try (RefreshingAddressResolverGroup group = builder.build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);

                final long start = System.nanoTime();

                final Future<InetSocketAddress> foo = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().untilAsserted(() -> assertThat(foo.isSuccess()).isTrue());
                assertThat(foo.getNow().getAddress().getHostAddress()).isEqualTo("1.1.1.1");

                final ConcurrentMap<String, CompletableFuture<CacheEntry>> cache = group.cache();
                await().until(cache::isEmpty);

                assertThat(System.nanoTime() - start).isGreaterThanOrEqualTo(
                        (long) (TimeUnit.SECONDS.toNanos(1) * 0.9));
            }
        }
    }

    /**
     * Resolves a host name twice, swaps the server's response to a different address and verifies that
     * the cached entry is eventually replaced by the new address, no sooner than 90% of one second after
     * the first resolution started.
     */
    @Test
    void refreshing() throws Exception {
        try (TestDnsServer server = new TestDnsServer(ImmutableMap.of(
                new DefaultDnsQuestion("baz.com.", A),
                new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("baz.com.", "1.1.1.1", 1))))
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            try (RefreshingAddressResolverGroup group = builder(server).build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);

                final long start = System.nanoTime();

                final Future<InetSocketAddress> foo = resolver.resolve(
                        InetSocketAddress.createUnresolved("baz.com", 36462));
                await().untilAsserted(() -> assertThat(foo.isSuccess()).isTrue());
                assertThat(foo.getNow().getAddress().getHostAddress()).isEqualTo("1.1.1.1");

                final ConcurrentMap<String, CompletableFuture<CacheEntry>> cache = group.cache();
                assertThat(cache.size()).isOne();
                assertThat(cache.get("baz.com").join().address()).isEqualTo(
                        InetAddress.getByAddress("baz.com", new byte[] { 1, 1, 1, 1 }));

                // Resolve one more to increase cache hits.
                resolver.resolve(InetSocketAddress.createUnresolved("baz.com", 36462));

                server.setResponses(ImmutableMap.of(
                        new DefaultDnsQuestion("baz.com.", A),
                        new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("baz.com.", "2.2.2.2"))));

                await().until(() -> {
                    final CompletableFuture<CacheEntry> future = cache.get("baz.com");
                    return future != null && future.join().address().equals(
                            InetAddress.getByAddress("baz.com", new byte[] { 2, 2, 2, 2 }));
                });

                assertThat(System.nanoTime() - start).isGreaterThanOrEqualTo(
                        (long) (TimeUnit.SECONDS.toNanos(1) * 0.9)); // ttl 1 second * buffer (90%)
            }
        }
    }

    /**
     * Configures a refresh {@link Backoff} limited to one attempt, removes every response from the server
     * after the first successful resolution and verifies that the cache entry is eventually removed and
     * that a subsequent resolution fails with an {@link UnknownHostException}.
     */
    @Test
    void removedWhenExceedingBackoffMaxAttempts() throws Exception {
        try (TestDnsServer server = new TestDnsServer(ImmutableMap.of(
                new DefaultDnsQuestion("foo.com.", A),
                new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "1.1.1.1", 1))))
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            final DnsResolverGroupBuilder builder = builder(server);
            builder.refreshBackoff(Backoff.ofDefault().withMaxAttempts(1));
            try (RefreshingAddressResolverGroup group = builder.build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);

                final long start = System.nanoTime();

                final Future<InetSocketAddress> foo = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().untilAsserted(() -> assertThat(foo.isSuccess()).isTrue());
                assertThat(foo.getNow().getAddress().getHostAddress()).isEqualTo("1.1.1.1");

                server.setResponses(ImmutableMap.of());

                // Schedule resolve() every 500 millis to keep cache hits greater than 0.
                for (int i = 1; i <= 4; i++) {
                    eventLoop.schedule(
                            () -> resolver.resolve(InetSocketAddress.createUnresolved("foo.com", 36462)),
                            500 * i, TimeUnit.MILLISECONDS);
                }

                final ConcurrentMap<String, CompletableFuture<CacheEntry>> cache = group.cache();
                await().until(cache::isEmpty);

                assertThat(System.nanoTime() - start).isGreaterThanOrEqualTo(
                        (long) (TimeUnit.SECONDS.toNanos(1) * 0.9)); // buffer (90%)

                final Future<InetSocketAddress> future = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().until(future::isDone);
                assertThat(future.cause()).isInstanceOf(UnknownHostException.class);
            }
        }
    }

    /**
     * Verifies that closing the resolver group cancels the refresh task of a cached entry and empties
     * the cache.
     */
    @Test
    void cacheClearWhenClosed() throws Exception {
        try (TestDnsServer server = new TestDnsServer(ImmutableMap.of(
                new DefaultDnsQuestion("foo.com.", A),
                new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "1.1.1.1"))))
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            // Not in a try-with-resources block because close() is the operation under test.
            final RefreshingAddressResolverGroup group = builder(server).build(eventLoop);
            final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);
            final Future<InetSocketAddress> foo = resolver.resolve(
                    InetSocketAddress.createUnresolved("foo.com", 36462));
            await().untilAsserted(() -> assertThat(foo.isSuccess()).isTrue());
            assertThat(foo.getNow().getAddress().getHostAddress()).isEqualTo("1.1.1.1");
            final ConcurrentMap<String, CompletableFuture<CacheEntry>> cache = group.cache();
            assertThat(cache.size()).isEqualTo(1);
            final CacheEntry cacheEntry = cache.get("foo.com").join();
            group.close();
            await().until(() -> {
                final ScheduledFuture<?> future = cacheEntry.refreshFuture;
                return future != null && future.isCancelled();
            });
            assertThat(cache).isEmpty();
        }
    }

    /**
     * Verifies that a resolution that failed because of a timeout is not cached, whereas a following
     * resolution that fails with a cause-less {@link UnknownHostException} is cached when a negative TTL
     * is configured.
     */
    @Test
    void negativeTtl() {
        // TimeoutHandler times out only the first query.
        try (TestDnsServer server = new TestDnsServer(ImmutableMap.of(), new TimeoutHandler())) {
            final EventLoop eventLoop = eventLoopExtension.get();
            final DnsResolverGroupBuilder builder = builder(server).negativeTtl(60).queryTimeoutMillis(1000);
            try (RefreshingAddressResolverGroup group = builder.build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);

                final Future<InetSocketAddress> future = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().until(future::isDone);

                final Throwable cause = future.cause();
                assertThat(cause).isInstanceOfAny(UnknownHostException.class,
                                                  DnsTimeoutException.class);
                if (cause instanceof UnknownHostException) {
                    assertThat(cause).hasCauseInstanceOf(DnsNameResolverTimeoutException.class);
                }

                // Because it's timed out, the result is not cached.
                final ConcurrentMap<String, CompletableFuture<CacheEntry>> cache = group.cache();
                assertThat(cache.size()).isZero();

                final Future<InetSocketAddress> future2 = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().until(future2::isDone);
                assertThat(future2.cause()).isInstanceOf(UnknownHostException.class)
                                           .hasNoCause();
                // Because it is NXDOMAIN, the result is cached.
                assertThat(cache.size()).isOne();
            }
        }
    }

    /**
     * Sends a request through a {@link WebClient} whose factory resolves addresses via five DNS servers,
     * the first four of which use a {@link TimeoutHandler}, and verifies that the request fails with an
     * {@link UnprocessedRequestException} whose root cause is exactly a {@link DnsTimeoutException}.
     */
    @Test
    void timeout() {
        try (TestDnsServer server1 = new TestDnsServer(ImmutableMap.of(), new TimeoutHandler());
             TestDnsServer server2 = new TestDnsServer(ImmutableMap.of(), new TimeoutHandler());
             TestDnsServer server3 = new TestDnsServer(ImmutableMap.of(), new TimeoutHandler());
             TestDnsServer server4 = new TestDnsServer(ImmutableMap.of(), new TimeoutHandler());
             TestDnsServer server5 = new TestDnsServer(ImmutableMap.of(
                     new DefaultDnsQuestion("foo.com.", A),
                     new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "1.1.1.1"))))
        ) {
            final DnsResolverGroupBuilder builder = builder(server1, server2, server3, server4, server5)
                    .negativeTtl(60)
                    .queryTimeoutMillis(1000);

            // Close the factory once the request has completed so that it is not leaked by this test.
            try (ClientFactory factory = ClientFactory.builder()
                                                      .addressResolverGroupFactory(builder::build)
                                                      .build()) {
                final WebClient client = WebClient.builder("http://foo.com").factory(factory).build();
                assertThatThrownBy(() -> client.get("/").aggregate().join())
                        .hasCauseInstanceOf(UnprocessedRequestException.class)
                        .hasRootCauseExactlyInstanceOf(DnsTimeoutException.class);
            }
        }
    }

    /**
     * Verifies that a resolution fails with a {@link DnsTimeoutException} when every configured server
     * drops every query.
     */
    @Test
    void returnDnsQuestionsWhenAllQueryTimeout() throws Exception {
        try (TestDnsServer server1 = new TestDnsServer(ImmutableMap.of(), new AlwaysTimeoutHandler());
             TestDnsServer server2 = new TestDnsServer(ImmutableMap.of(), new AlwaysTimeoutHandler())
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            final DnsResolverGroupBuilder builder = builder(server1, server2)
                    .queryTimeoutMillis(1000)
                    .resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED);
            try (RefreshingAddressResolverGroup group = builder.build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);
                final Future<InetSocketAddress> future = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().until(future::isDone);
                assertThat(future.cause()).isInstanceOf(DnsTimeoutException.class);
            }
        }
    }

    /**
     * Verifies that, with {@link ResolvedAddressTypes#IPV4_PREFERRED}, the IPv6 address is returned when
     * the server has a response only for the {@code AAAA} question.
     */
    @Test
    void returnPartialDnsQuestions() throws Exception {
        // Returns IPv6 correctly and make IPv4 timeout.
        try (TestDnsServer server = new TestDnsServer(
                ImmutableMap.of(
                        new DefaultDnsQuestion("foo.com.", AAAA),
                        new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "::1", 1))))
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            final DnsResolverGroupBuilder builder = builder(server)
                    .queryTimeoutMillis(1000)
                    .resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED);
            try (RefreshingAddressResolverGroup group = builder.build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);
                final Future<InetSocketAddress> future = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().until(future::isDone);
                assertThat(future.getNow().getAddress().getHostAddress()).isEqualTo("0:0:0:0:0:0:0:1");
            }
        }
    }

    /**
     * Verifies that, with {@link ResolvedAddressTypes#IPV4_PREFERRED}, the IPv4 address is returned even
     * though the {@code A} query is the one delayed by the {@link DelayHandler}.
     */
    @Test
    void preferredOrderIpv4() throws Exception {
        try (TestDnsServer server = new TestDnsServer(
                ImmutableMap.of(
                        new DefaultDnsQuestion("foo.com.", A),
                        new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "1.1.1.1")),
                        new DefaultDnsQuestion("foo.com.", AAAA),
                        new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "::1", 1))),
                new DelayHandler(A))
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            final DnsResolverGroupBuilder builder = builder(server)
                    .resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED);
            try (RefreshingAddressResolverGroup group = builder.build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);
                final Future<InetSocketAddress> future = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().until(future::isSuccess);
                assertThat(future.getNow().getAddress().getHostAddress()).isEqualTo("1.1.1.1");
            }
        }
    }

    /**
     * Verifies that, with {@link ResolvedAddressTypes#IPV6_PREFERRED}, the IPv6 address is returned even
     * though the {@code AAAA} query is the one delayed by the {@link DelayHandler}.
     */
    @Test
    void preferredOrderIpv6() throws Exception {
        try (TestDnsServer server = new TestDnsServer(
                ImmutableMap.of(
                        new DefaultDnsQuestion("foo.com.", A),
                        new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "1.1.1.1")),
                        new DefaultDnsQuestion("foo.com.", AAAA),
                        new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "::1", 1))),
                new DelayHandler(AAAA))
        ) {
            final EventLoop eventLoop = eventLoopExtension.get();
            final DnsResolverGroupBuilder builder = builder(server)
                    .resolvedAddressTypes(ResolvedAddressTypes.IPV6_PREFERRED);
            try (RefreshingAddressResolverGroup group = builder.build(eventLoop)) {
                final AddressResolver<InetSocketAddress> resolver = group.getResolver(eventLoop);
                final Future<InetSocketAddress> future = resolver.resolve(
                        InetSocketAddress.createUnresolved("foo.com", 36462));
                await().until(future::isSuccess);
                assertThat(future.getNow().getAddress().getHostAddress()).isEqualTo("0:0:0:0:0:0:0:1");
            }
        }
    }

    /**
     * Returns a {@link DnsResolverGroupBuilder} that sends its queries to the specified test servers.
     *
     * <p>The server addresses are provided through {@link DnsServerAddresses#sequential(Iterable)}, in
     * the order the servers were passed in. The builder is preset to
     * {@link ResolvedAddressTypes#IPV4_ONLY} with tracing disabled; callers may override either setting.
     *
     * @param servers the DNS servers to query
     * @return a new builder configured for the given servers
     */
    private static DnsResolverGroupBuilder builder(TestDnsServer... servers) {
        final DnsServerAddressStreamProvider dnsServerAddressStreamProvider =
                hostname -> DnsServerAddresses.sequential(
                        Stream.of(servers).map(TestDnsServer::addr).collect(toImmutableList())).stream();
        return new DnsResolverGroupBuilder()
                .dnsServerAddressStreamProvider(dnsServerAddressStreamProvider)
                .resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY)
                .traceEnabled(false);
    }

    /**
     * A handler that drops the first {@code A} query it reads and passes every other message on.
     */
    private static class TimeoutHandler extends ChannelInboundHandlerAdapter {
        // Number of 'A' queries seen so far; only the query that observes 0 is dropped.
        private int recordACount;

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof DatagramDnsQuery) {
                final DatagramDnsQuery dnsQuery = (DatagramDnsQuery) msg;
                final DnsRecord dnsRecord = dnsQuery.recordAt(DnsSection.QUESTION, 0);
                if (dnsRecord.type() == A && recordACount++ == 0) {
                    // Just release the msg and return so that the client request is timed out.
                    ReferenceCountUtil.safeRelease(msg);
                    return;
                }
            }
            super.channelRead(ctx, msg);
        }
    }

    /**
     * A handler that drops every {@link DatagramDnsQuery} it reads and passes every other message on.
     */
    private static class AlwaysTimeoutHandler extends ChannelInboundHandlerAdapter {
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof DatagramDnsQuery) {
                // Just release the msg and return so that the client request is timed out.
                ReferenceCountUtil.safeRelease(msg);
                return;
            }
            super.channelRead(ctx, msg);
        }
    }

    /**
     * A handler that postpones queries of one record type by one second and passes every other message
     * on immediately.
     */
    private static class DelayHandler extends ChannelInboundHandlerAdapter {
        private final DnsRecordType delayType;

        /**
         * Creates a new instance.
         *
         * @param delayType the type of the question whose handling is delayed
         */
        DelayHandler(DnsRecordType delayType) {
            this.delayType = delayType;
        }

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof DatagramDnsQuery) {
                final DatagramDnsQuery dnsQuery = (DatagramDnsQuery) msg;
                final DnsRecord dnsRecord = dnsQuery.recordAt(DnsSection.QUESTION, 0);
                if (dnsRecord.type() == delayType) {
                    // Hand the query to the rest of the pipeline one second later.
                    ctx.executor().schedule(() -> {
                        try {
                            super.channelRead(ctx, msg);
                        } catch (Exception e) {
                            // The scheduled task has no caller to propagate to, so report the failure
                            // through the pipeline instead of silently dropping it.
                            ctx.fireExceptionCaught(e);
                        }
                    }, 1, TimeUnit.SECONDS);
                    return;
                }
            }
            super.channelRead(ctx, msg);
        }
    }
}