package com.qiniu.android.dns.local;

import com.qiniu.android.dns.DnsException;
import com.qiniu.android.dns.Record;
import com.qiniu.android.dns.util.BitSet;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.IDN;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.HashSet;
//import java.util.logging.Level;
//import java.util.logging.Logger;

/**
 * reference github/rtreffer/minidns.
 */
public final class DnsMessage {

    public static byte[] buildQuery(String domain, int id) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream(512);
        DataOutputStream dos = new DataOutputStream(baos);
        BitSet bits = new BitSet();

//        recursionDesired
        bits.set(8);

        try {
            dos.writeShort((short) id);
            dos.writeShort((short) bits.value());

//        questions count
            dos.writeShort(1);

//      no  answer
            dos.writeShort(0);

//      no nameserverRecords
            dos.writeShort(0);

//      no  additionalResourceRecords
            dos.writeShort(0);

            dos.flush();
            writeQuestion(baos, domain);
        } catch (IOException e) {
            throw new AssertionError(e);
        }

        return baos.toByteArray();
    }

    private static void writeDomain(OutputStream out, String domain) throws IOException {
        for (String s : domain.split("[.\u3002\uFF0E\uFF61]")) {
            byte[] buffer = IDN.toASCII(s).getBytes();
            out.write(buffer.length);
            out.write(buffer, 0, buffer.length); // ?
        }
        out.write(0);
    }

    private static void writeQuestion(OutputStream out, String domain) throws IOException {
        DataOutputStream dos = new DataOutputStream(out);
        writeDomain(out, domain);
//        type A
        dos.writeShort(1);
//        class internet
        dos.writeShort(1);
    }

    public static Record[] parseResponse(byte[] response, int id, String domain) throws IOException {
        ByteArrayInputStream bis = new ByteArrayInputStream(response);
        DataInputStream dis = new DataInputStream(bis);
        int answerId = dis.readUnsignedShort();
        if (answerId != id) {
            throw new DnsException(domain, "the answer id " + answerId + " is not match " + id);
        }
        int header = dis.readUnsignedShort();
        boolean recursionDesired = ((header >> 8) & 1) == 1;
        boolean recursionAvailable = ((header >> 7) & 1) == 1;
        if (!(recursionAvailable && recursionDesired)) {
            throw new DnsException(domain, "the dns server cant support recursion ");

        }

        int questionCount = dis.readUnsignedShort();
        int answerCount = dis.readUnsignedShort();
//        nameserver Count
        dis.readUnsignedShort();
//        additionalResourceRecordCount
        dis.readUnsignedShort();

//         ignore questions
        readQuestions(dis, response, questionCount);

        return readAnswers(dis, response, answerCount);
//        ignore auth
//        ignore additional
    }

    /**
     * Parse a domain name starting at the current offset and moving the input
     * stream pointer past this domain name (even if cross references occure).
     *
     * @param dis  The input stream.
     * @param data The raw data (for cross references).
     * @return The domain name string.
     * @throws IOException Should never happen.
     */
    private static String readName(DataInputStream dis, byte data[])
            throws IOException {
        int c = dis.readUnsignedByte();
        if ((c & 0xc0) == 0xc0) {
            c = ((c & 0x3f) << 8) + dis.readUnsignedByte();
            HashSet<Integer> jumps = new HashSet<Integer>();
            jumps.add(c);
            return readName(data, c, jumps);
        }
        if (c == 0) {
            return "";
        }
        byte b[] = new byte[c];
        dis.readFully(b);
        String s = IDN.toUnicode(new String(b));
        String t = readName(dis, data);
        if (t.length() > 0) {
            s = s + "." + t;
        }
        return s;
    }

    /**
     * Parse a domain name starting at the given offset.
     *
     * @param data   The raw data.
     * @param offset The offset.
     * @param jumps  The list of jumps (by now).
     * @return The parsed domain name.
     * @throws IOException on cycles.
     */
    private static String readName(
            byte data[],
            int offset,
            HashSet<Integer> jumps
    ) throws IOException {
        int c = data[offset] & 0xff;
        if ((c & 0xc0) == 0xc0) {
            c = ((c & 0x3f) << 8) + (data[offset + 1] & 0xff);
            if (jumps.contains(c)) {
                throw new DnsException("", "Cyclic offsets detected.");
            }
            jumps.add(c);
            return readName(data, c, jumps);
        }
        if (c == 0) {
            return "";
        }
        String s = new String(data, offset + 1, c);
        String t = readName(data, offset + 1 + c, jumps);
        if (t.length() > 0) {
            s = s + "." + t;
        }
        return s;
    }

    private static void readQuestions(DataInputStream dis, byte[] data, int count) throws IOException {
        while (count-- > 0) {
            readName(dis, data);
//            type
            dis.readUnsignedShort();
//            class
            dis.readUnsignedShort();
        }
    }

    private static Record[] readAnswers(DataInputStream dis, byte[] data, int count) throws IOException {
        int offset = 0;
        Record[] ret = new Record[count];
        while (count-- > 0) {
            ret[offset++] = readRecord(dis, data);
        }
        return ret;
    }

    private static Record readRecord(DataInputStream dis, byte[] data) throws IOException {
        readName(dis, data);
        int type = dis.readUnsignedShort();
//            class
        dis.readUnsignedShort();

        long ttl = (((long) dis.readUnsignedShort()) << 16) +
                dis.readUnsignedShort();
        int payloadLength = dis.readUnsignedShort();
        String payload = null;
        switch (type) {
            case Record.TYPE_A:
                byte[] ip = new byte[4];
                dis.readFully(ip);
                payload = InetAddress.getByAddress(ip).getHostAddress();
                break;
            case Record.TYPE_CNAME:
                payload = readName(dis, data);
                break;
            default:
                payload = null;
                for (int i = 0; i < payloadLength; i++) {
                    dis.readByte();
                }
                break;
        }
        if (payload == null) {
            throw new UnknownHostException("no record");
        }
        return new Record(payload, type, (int) ttl, System.currentTimeMillis() / 1000);
    }
}