'''DNS resource records and messages: parsing from and packing to wire format.'''

import io
import socket
import struct
import time

from . import utils, types

__all__ = [
    'REQUEST',
    'RESPONSE',
    'DNSError',
    'Record',
    'DNSMessage',
]

# Values of the QR flag / record kind.
REQUEST = 0
RESPONSE = 1

# TTL written to the wire by Record.pack() for records whose ttl is negative.
MAXAGE = 3600000


class DNSError(Exception):
    '''Error carrying a DNS reply code.

    For the codes listed in ``errors`` the predefined text is always used and
    the ``message`` argument is ignored. For any other code ``message`` is
    used, falling back to ``'Unknown reply code: <code>'`` when it is empty.

    Attributes:
        code: the reply code passed to the constructor.
    '''

    errors = {
        1: 'Format error: bad request',
        2: 'Server failure: error occurred',
        3: 'Name error: not exist',
        4: 'Not implemented: query type not supported',
        5: 'Refused: policy reasons',
    }

    def __init__(self, code, message=None):
        message = self.errors.get(code, message) or 'Unknown reply code: %d' % code
        super().__init__(message)
        self.code = code


class RData:
    '''Base class of RData'''

    # Record type code; subclasses override it with a value from ``types``.
    rtype = -1

    @property
    def type_name(self):
        '''Lower-cased name of ``rtype`` as reported by ``types.get_name``.'''
        return types.get_name(self.rtype).lower()


class SOA_RData(RData):
    '''Start of Authority record'''

    rtype = types.SOA

    def __init__(self, *k):
        # Exactly seven positional values are expected, in this order.
        (
            self.mname,
            self.rname,
            self.serial,
            self.refresh,
            self.retry,
            self.expire,
            self.minimum,
        ) = k

    def __repr__(self):
        return '<%s: %s>' % (self.type_name, self.rname)

    @classmethod
    def load(cls, data, l):
        '''Parse an SOA RDATA starting at offset ``l`` of ``data``.

        Returns a ``(next_offset, SOA_RData)`` tuple.
        '''
        i, mname = utils.load_message(data, l)
        i, rname = utils.load_message(data, i)
        # Five unsigned 32-bit big-endian integers follow the two names.
        (
            serial,
            refresh,
            retry,
            expire,
            minimum,
        ) = struct.unpack('!LLLLL', data[i: i + 20])
        return i + 20, cls(mname, rname, serial, refresh, retry, expire, minimum)

    def dump(self, pack_name, offset):
        '''Yield the RDATA as byte chunks.

        ``pack_name(name, offset)`` packs a name that will be located at the
        given message offset; ``offset`` is the position of the 2-byte RDATA
        length field, hence the ``+ 2`` below.
        '''
        mname = pack_name(self.mname, offset + 2)
        yield mname
        yield pack_name(self.rname, offset + 2 + len(mname))
        yield struct.pack('!LLLLL', self.serial, self.refresh, self.retry,
                          self.expire, self.minimum)


class MX_RData(RData):
    '''Mail exchanger record'''

    rtype = types.MX

    def __init__(self, *k):
        # Exactly two positional values: preference, exchange.
        self.preference, self.exchange = k

    def __repr__(self):
        return '<%s-%s: %s>' % (self.type_name, self.preference, self.exchange)

    @classmethod
    def load(cls, data, l):
        '''Parse an MX RDATA at offset ``l``; return ``(next_offset, MX_RData)``.'''
        preference, = struct.unpack('!H', data[l: l + 2])
        i, exchange = utils.load_message(data, l + 2)
        return i, cls(preference, exchange)

    def dump(self, pack_name, offset):
        '''Yield the RDATA as byte chunks (see ``SOA_RData.dump`` for arguments).'''
        yield struct.pack('!H', self.preference)
        # 2 bytes of RDATA length + 2 bytes of preference precede the name.
        yield pack_name(self.exchange, offset + 4)


class SRV_RData(RData):
    '''Service record'''

    rtype = types.SRV

    def __init__(self, *k):
        # Exactly four positional values: priority, weight, port, hostname.
        self.priority, self.weight, self.port, self.hostname = k

    def __repr__(self):
        return '<%s-%s: %s:%s>' % (self.type_name, self.priority, self.hostname, self.port)

    @classmethod
    def load(cls, data, l):
        '''Parse an SRV RDATA at offset ``l``; return ``(next_offset, SRV_RData)``.'''
        priority, weight, port = struct.unpack('!HHH', data[l: l + 6])
        i, hostname = utils.load_message(data, l + 6)
        return i, cls(priority, weight, port, hostname)

    def dump(self, pack_name, offset):
        '''Yield the RDATA as byte chunks (see ``SOA_RData.dump`` for arguments).'''
        yield struct.pack('!HHH', self.priority, self.weight, self.port)
        # 2 bytes of RDATA length + 6 bytes of fixed fields precede the name.
        yield pack_name(self.hostname, offset + 8)


class NAPTR_RData(RData):
    '''NAPTR record'''

    rtype = types.NAPTR

    def __init__(self, *k):
        # Exactly six positional values, in this order.
        (
            self.order,
            self.preference,
            self.flags,
            self.service,
            self.regexp,
            self.replacement,
        ) = k

    def __repr__(self):
        return '<%s-%s-%s: %s %s %s %s>' % (
            self.type_name, self.order, self.preference,
            self.flags, self.service, self.regexp, self.replacement)

    @classmethod
    def load(cls, data, l):
        '''Parse a NAPTR RDATA at offset ``l``; return ``(next_offset, NAPTR_RData)``.'''
        pos = l
        order, preference = struct.unpack('!HH', data[pos: pos + 4])
        pos += 4

        # flags, service and regexp are each a 1-byte length followed by that
        # many bytes, decoded with the default codec.
        fields = []
        for _ in range(3):
            length = data[pos]
            pos += 1
            fields.append(data[pos: pos + length].decode())
            pos += length
        flags, service, regexp = fields

        i, replacement = utils.load_message(data, pos, lower=False)
        return i, cls(order, preference, flags, service, regexp, replacement)

    def dump(self, pack_name, offset):
        '''Packing NAPTR records is not supported.'''
        raise NotImplementedError


class Record:
    '''A single question (``q == REQUEST``) or resource record (``q == RESPONSE``).

    Attributes:
        q: REQUEST or RESPONSE.
        name: the owner / queried name.
        qtype: record type code.
        qclass: record class code.
        ttl: time to live in seconds; 0 means the item should not be cached.
        data: record payload (a string, raw bytes or an ``RData`` instance).
        timestamp: integer UNIX time at which ``ttl`` was last valid.
    '''

    def __init__(self, q=RESPONSE, name='', qtype=types.ANY, qclass=1, ttl=0, data=None):
        self.q = q
        self.name = name
        self.qtype = qtype
        self.qclass = qclass
        # Always initialise these, also for request records: copy() reads them
        # unconditionally and would otherwise fail with AttributeError.
        self.ttl = ttl  # 0 means item should not be cached
        self.data = data
        self.timestamp = int(time.time())

    def __repr__(self):
        if self.q == REQUEST:
            return f'<Record type=request qtype={types.get_name(self.qtype)} name={self.name}>'
        else:
            return f'<Record type=response qtype={types.get_name(self.qtype)} name={self.name} ttl={self.ttl} data={self.data}>'

    def copy(self, **kw):
        '''Return a new Record with the given fields overridden.

        Accepted keywords are ``q``, ``name``, ``qtype``, ``qclass``, ``ttl``
        and ``data``. The timestamp is not copied; the new record is stamped
        with the current time.
        '''
        return Record(
            q=kw.get('q', self.q),
            name=kw.get('name', self.name),
            qtype=kw.get('qtype', self.qtype),
            qclass=kw.get('qclass', self.qclass),
            ttl=kw.get('ttl', self.ttl),
            data=kw.get('data', self.data),
        )

    def parse(self, data, l):
        '''Fill this record from ``data`` starting at offset ``l``.

        Returns the offset of the first byte after the record.
        '''
        l, self.name = utils.load_message(data, l)
        self.qtype, self.qclass = struct.unpack('!HH', data[l: l + 4])
        l += 4
        if self.q == RESPONSE:
            self.timestamp = int(time.time())
            # ttl (32 bit) followed by the RDATA length (16 bit).
            self.ttl, dl = struct.unpack('!LH', data[l: l + 6])
            l += 6
            if self.qtype == types.A:
                self.data = socket.inet_ntoa(data[l: l + dl])
            elif self.qtype == types.AAAA:
                self.data = socket.inet_ntop(socket.AF_INET6, data[l: l + dl])
            elif self.qtype == types.MX:
                _, self.data = MX_RData.load(data, l)
            elif self.qtype == types.SRV:
                _, self.data = SRV_RData.load(data, l)
            elif self.qtype == types.NAPTR:
                _, self.data = NAPTR_RData.load(data, l)
            elif self.qtype == types.SOA:
                _, self.data = SOA_RData.load(data, l)
            elif self.qtype in (types.CNAME, types.NS, types.PTR, types.TXT):
                # NOTE(review): TXT is decoded through utils.load_message just
                # like the name-valued types - confirm this is intended.
                _, self.data = utils.load_message(data, l)
            else:
                # Unknown type: keep the raw RDATA bytes.
                self.data = data[l: l + dl]
            # The declared RDATA length decides where the next record starts,
            # whatever offset the type-specific loader ended at.
            l += dl
        return l

    def pack(self, names, offset=0):
        '''Serialise this record to bytes.

        Args:
            names: mapping handed to ``utils.pack_message`` together with each
                name and its message offset (presumably the name compression
                table - confirm against ``utils``).
            offset: offset inside the message at which this record will start.

        For response records with a non-negative ttl, the ttl is aged by the
        time elapsed since ``timestamp`` (clamped at 0), and ``ttl`` and
        ``timestamp`` are updated in place. A negative ttl is sent as MAXAGE
        and left untouched.
        '''
        def pack_name(name, pack_offset):
            return utils.pack_message(name, names, pack_offset)

        buf = io.BytesIO()
        buf.write(utils.pack_message(self.name, names, offset))
        buf.write(struct.pack('!HH', self.qtype, self.qclass))
        if self.q == RESPONSE:
            if self.ttl < 0:
                ttl = MAXAGE
            else:
                now = int(time.time())
                self.ttl = max(self.ttl - (now - self.timestamp), 0)
                self.timestamp = now
                ttl = self.ttl
            buf.write(struct.pack('!L', ttl))

            # buf.tell() is now the position of the RDATA length field.
            if isinstance(self.data, RData):
                data_str = b''.join(self.data.dump(pack_name, offset + buf.tell()))
                buf.write(utils.pack_string(data_str, '!H'))
            elif self.qtype == types.A:
                buf.write(utils.pack_string(socket.inet_aton(self.data), '!H'))
            elif self.qtype == types.AAAA:
                buf.write(utils.pack_string(socket.inet_pton(socket.AF_INET6, self.data), '!H'))
            elif self.qtype in (types.CNAME, types.NS, types.PTR, types.TXT):
                # + 2 skips the RDATA length field that precedes the name.
                name = pack_name(self.data, offset + buf.tell() + 2)
                buf.write(utils.pack_string(name, '!H'))
            else:
                buf.write(utils.pack_string(self.data, '!H'))
        return buf.getvalue()


class DNSMessage:
    '''A DNS message: header flags plus the four record sections.

    Truthiness reflects whether the message has any answer or authority
    records; indexing and iteration operate on the answer section.
    '''

    def __init__(self, qr=RESPONSE, qid=0, o=0, aa=0, tc=0, rd=1, ra=0, r=0):
        self.qr = qr      # 0 for request, 1 for response
        self.qid = qid    # id for UDP package
        self.o = o        # opcode: 0 for standard query
        self.aa = aa      # Authoritative Answer
        self.tc = tc      # TrunCation, will be updated on .pack()
        self.rd = rd      # Recursion Desired for request
        self.ra = ra      # Recursion Available for response
        self.r = r        # rcode: 0 for success
        self.qd = []      # questions
        self.an = []      # answers
        self.ns = []      # authority records, aka nameservers
        self.ar = []      # additional records

    def __bool__(self):
        return any(map(len, (self.an, self.ns)))

    def __getitem__(self, i):
        return self.an[i]

    def __iter__(self):
        return iter(self.an)

    def __repr__(self):
        return '<DNSMessage type=%s qid=%d r=%d QD=%s AN=%s NS=%s AR=%s>' % (
            self.qr, self.qid, self.r, self.qd, self.an, self.ns, self.ar)

    def pack(self, size_limit=None):
        '''Serialise the message to bytes.

        Args:
            size_limit: maximum size of the result in bytes, or None for no
                limit. The first record that would exceed the limit and every
                record after it are left out.

        Side effect: ``self.tc`` is set to 1 if records were left out, else 0.

        The section counts written to the header are the numbers of records
        actually emitted, so a truncated message stays parseable.
        '''
        z = 0  # reserved header bits
        names = {}
        buf = io.BytesIO()
        buf.seek(12)  # leave room for the header, written last
        tc = 0
        counts = []
        for group in (self.qd, self.an, self.ns, self.ar):
            written = 0
            if not tc:
                for rec in group:
                    offset = buf.tell()
                    brec = rec.pack(names, offset)
                    if size_limit is not None and offset + len(brec) > size_limit:
                        tc = 1
                        break
                    buf.write(brec)
                    written += 1
            counts.append(written)
        self.tc = tc
        buf.seek(0)
        buf.write(struct.pack(
            '!HHHHHH',
            self.qid,
            (self.qr << 15) + (self.o << 11) + (self.aa << 10) + (self.tc << 9)
            + (self.rd << 8) + (self.ra << 7) + (z << 4) + self.r,
            *counts
        ))
        return buf.getvalue()

    @staticmethod
    def parse_entry(qr, data, l, n):
        '''Parse ``n`` consecutive records of kind ``qr`` starting at offset ``l``.

        Returns ``(next_offset, list_of_records)``.
        '''
        res = []
        for _ in range(n):
            r = Record(qr)
            l = r.parse(data, l)
            res.append(r)
        return l, res

    @classmethod
    def parse(cls, data, qid=None):
        '''Build a DNSMessage from wire-format ``data``.

        Args:
            data: the raw message bytes.
            qid: expected transaction id; when given and different from the
                id found in ``data``, ``DNSError(-1)`` is raised.
        '''
        rqid, x, qd, an, ns, ar = struct.unpack('!HHHHHH', data[:12])
        if qid is not None and qid != rqid:
            raise DNSError(-1, 'Transaction ID mismatch')
        # Peel the flag fields off the 16-bit flags word one by one.
        r, x = utils.get_bits(x, 4)    # rcode: 0 for no error
        _z, x = utils.get_bits(x, 3)   # reserved
        ra, x = utils.get_bits(x, 1)   # recursion available
        rd, x = utils.get_bits(x, 1)   # recursion desired
        tc, x = utils.get_bits(x, 1)   # truncation
        aa, x = utils.get_bits(x, 1)   # authoritative answer
        o, x = utils.get_bits(x, 4)    # opcode
        qr, x = utils.get_bits(x, 1)   # qr: 0 for query and 1 for response
        ans = cls(qr, rqid, o, aa, tc, rd, ra, r)
        l, ans.qd = ans.parse_entry(REQUEST, data, 12, qd)
        l, ans.an = ans.parse_entry(RESPONSE, data, l, an)
        l, ans.ns = ans.parse_entry(RESPONSE, data, l, ns)
        l, ans.ar = ans.parse_entry(RESPONSE, data, l, ar)
        return ans