// Copyright (c) 1999-2004 Brian Wellington ([email protected])

package org.xbill.DNS;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import lombok.extern.slf4j.Slf4j;

/**
 * An implementation of Resolver that sends one query to one server. SimpleResolver handles TCP
 * retries, transaction security (TSIG), and EDNS 0.
 *
 * @see Resolver
 * @see TSIG
 * @see OPTRecord
 * @author Brian Wellington
 */
@Slf4j
public class SimpleResolver implements Resolver {

  /** The default port to send queries to */
  public static final int DEFAULT_PORT = 53;

  /** The default EDNS payload size */
  public static final int DEFAULT_EDNS_PAYLOADSIZE = 1280;

  private InetSocketAddress address;
  private InetSocketAddress localAddress;
  private boolean useTCP, ignoreTruncation;
  private OPTRecord queryOPT = new OPTRecord(DEFAULT_EDNS_PAYLOADSIZE, 0, 0, 0);
  private TSIG tsig;
  private Duration timeoutValue = Duration.ofSeconds(10);

  private static final short DEFAULT_UDPSIZE = 512;

  private static InetSocketAddress defaultResolver =
      new InetSocketAddress(InetAddress.getLoopbackAddress(), DEFAULT_PORT);

  /**
   * Creates a SimpleResolver. The host to query is either found by using ResolverConfig, or the
   * default host is used.
   *
   * @see ResolverConfig
   * @exception UnknownHostException Failure occurred while finding the host
   */
  public SimpleResolver() throws UnknownHostException {
    this((String) null);
  }

  /**
   * Creates a SimpleResolver that will query the specified host
   *
   * @exception UnknownHostException Failure occurred while finding the host
   */
  public SimpleResolver(String hostname) throws UnknownHostException {
    if (hostname == null) {
      address = ResolverConfig.getCurrentConfig().server();
      if (address == null) {
        address = defaultResolver;
      }

      return;
    }

    InetAddress addr;
    if ("0".equals(hostname)) {
      addr = InetAddress.getLoopbackAddress();
    } else {
      addr = InetAddress.getByName(hostname);
    }

    address = new InetSocketAddress(addr, DEFAULT_PORT);
  }

  /** Creates a SimpleResolver that will query the specified host */
  public SimpleResolver(InetSocketAddress host) {
    address = Objects.requireNonNull(host, "host must not be null");
  }

  /** Creates a SimpleResolver that will query the specified host */
  public SimpleResolver(InetAddress host) {
    Objects.requireNonNull(host, "host must not be null");
    address = new InetSocketAddress(host, DEFAULT_PORT);
  }

  /**
   * Gets the destination address associated with this SimpleResolver. Messages sent using this
   * SimpleResolver will be sent to this address.
   *
   * @return The destination address associated with this SimpleResolver.
   */
  public InetSocketAddress getAddress() {
    return address;
  }

  /** Sets the default host (initially localhost) to query */
  public static void setDefaultResolver(InetSocketAddress hostname) {
    defaultResolver = hostname;
  }

  /** Sets the default host (initially localhost) to query */
  public static void setDefaultResolver(String hostname) {
    defaultResolver = new InetSocketAddress(hostname, DEFAULT_PORT);
  }

  /**
   * Gets the port to communicate with on the server
   *
   * @since 3.2
   */
  public int getPort() {
    return address.getPort();
  }

  @Override
  public void setPort(int port) {
    address = new InetSocketAddress(address.getAddress(), port);
  }

  /**
   * Sets the address of the server to communicate with.
   *
   * @param addr The address of the DNS server
   */
  public void setAddress(InetSocketAddress addr) {
    address = addr;
  }

  /**
   * Sets the address of the server to communicate with (on the default DNS port)
   *
   * @param addr The address of the DNS server
   */
  public void setAddress(InetAddress addr) {
    address = new InetSocketAddress(addr, address.getPort());
  }

  /**
   * Sets the local address to bind to when sending messages.
   *
   * @param addr The local address to send messages from.
   */
  public void setLocalAddress(InetSocketAddress addr) {
    localAddress = addr;
  }

  /**
   * Sets the local address to bind to when sending messages. A random port will be used.
   *
   * @param addr The local address to send messages from.
   */
  public void setLocalAddress(InetAddress addr) {
    localAddress = new InetSocketAddress(addr, 0);
  }

  /**
   * Gets whether TCP connections will be used by default
   *
   * @since 3.2
   */
  public boolean getTCP() {
    return useTCP;
  }

  @Override
  public void setTCP(boolean flag) {
    this.useTCP = flag;
  }

  /**
   * Gets whether truncated responses will be ignored.
   *
   * @since 3.2
   */
  public boolean getIgnoreTruncation() {
    return ignoreTruncation;
  }

  @Override
  public void setIgnoreTruncation(boolean flag) {
    this.ignoreTruncation = flag;
  }

  /**
   * Gets the EDNS information on outgoing messages.
   *
   * @return The current {@link OPTRecord} for EDNS or {@code null} if EDNS is disabled.
   * @since 3.2
   */
  public OPTRecord getEDNS() {
    return queryOPT;
  }

  /**
   * Sets the EDNS information on outgoing messages.
   *
   * @param optRecord the {@link OPTRecord} for EDNS options or null to disable EDNS.
   * @see #setEDNS(int, int, int, List)
   * @since 3.2
   */
  public void setEDNS(OPTRecord optRecord) {
    queryOPT = optRecord;
  }

  @Override
  public void setEDNS(int version, int payloadSize, int flags, List<EDNSOption> options) {
    switch (version) {
      case -1:
        queryOPT = null;
        break;

      case 0:
        if (payloadSize == 0) {
          payloadSize = DEFAULT_EDNS_PAYLOADSIZE;
        }
        queryOPT = new OPTRecord(payloadSize, 0, version, flags, options);
        break;

      default:
        throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable");
    }
  }

  /**
   * Get the TSIG key that messages will be signed with.
   *
   * @return the TSIG signature for outgoing messages or {@code null} if not specified.
   * @since 3.2
   */
  public TSIG getTSIGKey() {
    return tsig;
  }

  @Override
  public void setTSIGKey(TSIG key) {
    tsig = key;
  }

  @Override
  public void setTimeout(Duration timeout) {
    timeoutValue = timeout;
  }

  @Override
  public Duration getTimeout() {
    return timeoutValue;
  }

  private Message parseMessage(byte[] b) throws WireParseException {
    try {
      return new Message(b);
    } catch (IOException e) {
      if (!(e instanceof WireParseException)) {
        e = new WireParseException("Error parsing message");
      }
      throw (WireParseException) e;
    }
  }

  private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) {
    if (tsig == null) {
      return;
    }
    int error = tsig.verify(response, b, query.getTSIG());
    log.debug("TSIG verify: {}", Rcode.TSIGstring(error));
  }

  private void applyEDNS(Message query) {
    if (queryOPT == null || query.getOPT() != null) {
      return;
    }
    query.addRecord(queryOPT, Section.ADDITIONAL);
  }

  private int maxUDPSize(Message query) {
    OPTRecord opt = query.getOPT();
    if (opt == null) {
      return DEFAULT_UDPSIZE;
    } else {
      return opt.getPayloadSize();
    }
  }

  /**
   * Asynchronously sends a message to a single server, registering a listener to receive a callback
   * on success or exception. Multiple asynchronous lookups can be performed in parallel. Since the
   * callback may be invoked before the function returns, external synchronization is necessary.
   *
   * @param query The query to send
   * @return A future that completes when the response has arrived.
   */
  @Override
  public CompletionStage<Message> sendAsync(Message query) {
    if (query.getHeader().getOpcode() == Opcode.QUERY) {
      Record question = query.getQuestion();
      if (question != null && question.getType() == Type.AXFR) {
        CompletableFuture<Message> f = new CompletableFuture<>();
        CompletableFuture.runAsync(
            () -> {
              try {
                f.complete(sendAXFR(query));
              } catch (IOException e) {
                f.completeExceptionally(e);
              }
            });

        return f;
      }
    }

    Message ednsTsigQuery = query.clone();
    applyEDNS(ednsTsigQuery);
    if (tsig != null) {
      ednsTsigQuery.setTSIG(tsig, Rcode.NOERROR, null);
    }

    return sendAsync(ednsTsigQuery, useTCP);
  }

  CompletableFuture<Message> sendAsync(Message query, boolean forceTcp) {
    int qid = query.getHeader().getID();
    byte[] out = query.toWire(Message.MAXLENGTH);
    int udpSize = maxUDPSize(query);
    boolean tcp = forceTcp || out.length > udpSize;
    log.debug(
        "Sending {}/{}, id={} to {}/{}:{}",
        query.getQuestion().getName(),
        Type.string(query.getQuestion().getType()),
        qid,
        tcp ? "tcp" : "udp",
        address.getAddress().getHostAddress(),
        address.getPort());
    log.trace("Query:\n{}", query);

    CompletableFuture<byte[]> result;
    if (tcp) {
      result = NioTcpClient.sendrecv(localAddress, address, query, out, timeoutValue);
    } else {
      result = NioUdpClient.sendrecv(localAddress, address, out, udpSize, timeoutValue);
    }

    return result.thenComposeAsync(
        in -> {
          CompletableFuture<Message> f = new CompletableFuture<>();

          // Check that the response is long enough.
          if (in.length < Header.LENGTH) {
            f.completeExceptionally(new WireParseException("invalid DNS header - too short"));
            return f;
          }

          // Check that the response ID matches the query ID. We want
          // to check this before actually parsing the message, so that
          // if there's a malformed response that's not ours, it
          // doesn't confuse us.
          int id = ((in[0] & 0xFF) << 8) + (in[1] & 0xFF);
          if (id != qid) {
            f.completeExceptionally(
                new WireParseException("invalid message id: expected " + qid + "; got id " + id));
            return f;
          }

          Message response;
          try {
            response = parseMessage(in);
          } catch (WireParseException e) {
            f.completeExceptionally(e);
            return f;
          }

          // validate name, class and type (rfc5452#section-9.1)
          if (!query.getQuestion().getName().equals(response.getQuestion().getName())) {
            f.completeExceptionally(
                new WireParseException(
                    "invalid name in message: expected "
                        + query.getQuestion().getName()
                        + "; got "
                        + response.getQuestion().getName()));
            return f;
          }

          if (query.getQuestion().getDClass() != response.getQuestion().getDClass()) {
            f.completeExceptionally(
                new WireParseException(
                    "invalid class in message: expected "
                        + DClass.string(query.getQuestion().getDClass())
                        + "; got "
                        + DClass.string(response.getQuestion().getDClass())));
            return f;
          }

          if (query.getQuestion().getType() != response.getQuestion().getType()) {
            f.completeExceptionally(
                new WireParseException(
                    "invalid type in message: expected "
                        + Type.string(query.getQuestion().getType())
                        + "; got "
                        + Type.string(response.getQuestion().getType())));
            return f;
          }

          verifyTSIG(query, response, in, tsig);
          if (!tcp && !ignoreTruncation && response.getHeader().getFlag(Flags.TC)) {
            log.debug("Got truncated response for id {}, retrying via TCP", qid);
            log.trace("Truncated response: {}", response);
            return sendAsync(query, true);
          }

          response.setResolver(this);
          f.complete(response);
          return f;
        });
  }

  private Message sendAXFR(Message query) throws IOException {
    Name qname = query.getQuestion().getName();
    ZoneTransferIn xfrin = ZoneTransferIn.newAXFR(qname, address, tsig);
    xfrin.setTimeout(timeoutValue);
    xfrin.setLocalAddress(localAddress);
    try {
      xfrin.run();
    } catch (ZoneTransferException e) {
      throw new WireParseException(e.getMessage());
    }
    List<Record> records = xfrin.getAXFR();
    Message response = new Message(query.getHeader().getID());
    response.getHeader().setFlag(Flags.AA);
    response.getHeader().setFlag(Flags.QR);
    response.addRecord(query.getQuestion(), Section.QUESTION);
    for (Record record : records) {
      response.addRecord(record, Section.ANSWER);
    }
    return response;
  }

  @Override
  public String toString() {
    return "SimpleResolver [" + address + "]";
  }
}